From 568fcf59f023bb5f16f98d020a1e8147d6d0ac2e Mon Sep 17 00:00:00 2001 From: Joseph Pagadora Date: Thu, 26 Jun 2025 14:15:27 -0700 Subject: [PATCH] feat: Propagate LLM agent tools to populate function API spec during agent evaluation PiperOrigin-RevId: 776278342 --- src/google/adk/evaluation/constants.py | 2 + src/google/adk/evaluation/eval_case.py | 19 +- src/google/adk/evaluation/eval_metrics.py | 21 ++ .../adk/evaluation/evaluation_generator.py | 13 + .../adk/evaluation/final_response_match_v2.py | 173 ++++++++++ src/google/adk/evaluation/llm_as_judge.py | 136 ++++++++ .../test_final_response_match_v2.py | 307 ++++++++++++++++++ .../unittests/evaluation/test_llm_as_judge.py | 208 ++++++++++++ 8 files changed, 876 insertions(+), 3 deletions(-) create mode 100644 src/google/adk/evaluation/final_response_match_v2.py create mode 100644 src/google/adk/evaluation/llm_as_judge.py create mode 100644 tests/unittests/evaluation/test_final_response_match_v2.py create mode 100644 tests/unittests/evaluation/test_llm_as_judge.py diff --git a/src/google/adk/evaluation/constants.py b/src/google/adk/evaluation/constants.py index 74248ed18..94cddc158 100644 --- a/src/google/adk/evaluation/constants.py +++ b/src/google/adk/evaluation/constants.py @@ -18,3 +18,5 @@ "Eval module is not installed, please install via `pip install" " google-adk[eval]`." ) + +DEFAULT_JUDGE_MODEL = "gemini-2.5-flash" diff --git a/src/google/adk/evaluation/eval_case.py b/src/google/adk/evaluation/eval_case.py index 172a8309d..91fc1a3ef 100644 --- a/src/google/adk/evaluation/eval_case.py +++ b/src/google/adk/evaluation/eval_case.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from typing import Any from typing import Optional -from typing import Tuple from google.genai import types as genai_types from pydantic import alias_generators @@ -31,17 +31,27 @@ class EvalBaseModel(BaseModel): ) +class FunctionSpec(EvalBaseModel): + """Function specification.""" + + name: str + """The name of the function.""" + + description: str + """The description of the function, including parameters and return type.""" + + class IntermediateData(EvalBaseModel): """Container for intermediate data that an agent would generate as it responds with a final answer.""" tool_uses: list[genai_types.FunctionCall] = [] """Tool use trajectory in chronological order.""" - intermediate_responses: list[Tuple[str, list[genai_types.Part]]] = [] + intermediate_responses: list[tuple[str, list[genai_types.Part]]] = [] """Intermediate responses generated by sub-agents to convey progress or status in a multi-agent system, distinct from the final response. - This is expressed as a Tuple of: + This is expressed as a tuple of: - Author: Usually the sub-agent name that generated the intermediate response. 
@@ -71,6 +81,9 @@ class Invocation(EvalBaseModel): creation_timestamp: float = 0.0 """Timestamp for the current invocation, primarily intended for debugging purposes.""" + function_api_spec: list[FunctionSpec] = Field(default_factory=list) + """Function API spec for the invocation.""" + class SessionInput(EvalBaseModel): """Values that help initialize a Session.""" diff --git a/src/google/adk/evaluation/eval_metrics.py b/src/google/adk/evaluation/eval_metrics.py index 91ef6e6f6..a82fb2d79 100644 --- a/src/google/adk/evaluation/eval_metrics.py +++ b/src/google/adk/evaluation/eval_metrics.py @@ -16,14 +16,30 @@ from typing import Optional +from google.genai import types as genai_types from pydantic import alias_generators from pydantic import BaseModel from pydantic import ConfigDict +from pydantic import Field from .eval_case import Invocation from .evaluator import EvalStatus +class JudgeModelOptions(BaseModel): + """Options for an eval metric's judge model.""" + + judge_model: str + """The judge model to use for evaluation. It can be a model name.""" + + judge_model_config: Optional[genai_types.GenerateContentConfig] = Field( + default=None, description="""The configuration for the judge model.""" + ) + + num_samples: int = 1 + """The number of samples to generate for each invocation.""" + + class EvalMetric(BaseModel): """A metric used to evaluate a particular aspect of an eval case.""" @@ -38,6 +54,11 @@ class EvalMetric(BaseModel): threshold: float """A threshold value. Each metric decides how to interpret this threshold.""" + judge_model_options: Optional[JudgeModelOptions] = Field( + default=None, + description="""Options for the judge model, if applicable.""", + ) + class EvalMetricResult(EvalMetric): """The actual computed score/value of a particular EvalMetric.""" diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index 1359967bc..22fc37acf 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -29,6 +29,7 @@ from ..sessions.in_memory_session_service import InMemorySessionService from ..sessions.session import Session from .eval_case import EvalCase +from .eval_case import FunctionSpec from .eval_case import IntermediateData from .eval_case import Invocation from .eval_case import SessionInput @@ -153,6 +154,17 @@ async def _generate_inferences_from_root_agent( user_id = initial_session.user_id if initial_session else "test_user_id" session_id = session_id if session_id else str(uuid.uuid4()) + function_specs = [] + if hasattr(root_agent, "canonical_tools"): + tools = await root_agent.canonical_tools() + for tool in tools: + function_specs.append( + FunctionSpec( + name=tool.name, + description=tool.description, + ) + ) + _ = await session_service.create_session( app_name=app_name, user_id=user_id, @@ -201,6 +213,7 @@ async def _generate_inferences_from_root_agent( user_content=user_content, final_response=final_response, intermediate_data=IntermediateData(tool_uses=tool_uses), + function_api_spec=function_specs, ) ) diff --git a/src/google/adk/evaluation/final_response_match_v2.py b/src/google/adk/evaluation/final_response_match_v2.py new file mode 100644 index 000000000..8f4bbb002 --- /dev/null +++ b/src/google/adk/evaluation/final_response_match_v2.py @@ -0,0 +1,173 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import json
+import logging
+import re
+from typing import Optional
+
+from typing_extensions import override
+
+from ..models.llm_response import LlmResponse
+from .eval_case import FunctionSpec
+from .eval_case import Invocation
+from .eval_metrics import EvalMetric
+from .evaluator import EvalStatus
+from .evaluator import EvaluationResult
+from .evaluator import PerInvocationResult
+from .llm_as_judge import get_eval_status
+from .llm_as_judge import get_text_from_content
+from .llm_as_judge import LlmAsJudge
+
+logger = logging.getLogger("google_adk." + __name__)
+
+FINAL_RESPONSE_MATCH_V2_PROMPT = """
+You are an expert rater for an AI agent. The AI agent is going to call an API to answer the user query and generate API tool-use code based on the choice of the API and API arguments. The ideal model response should be a function call that fulfills the user query, or a natural language response that hedges or asks the user for further clarification if a function call does not apply.
+The primary focus of this rating task is to check the correctness of the model responses.
+
+The data consists of:
+- A set of Python function definitions available to the agent.
+- A user query.
+- A model-generated response for the prompt. The response can consist of:
+  - Natural language, when the model is asking for clarification, or telling the user it does not possess the requested functionality / option.
+  - Code, in the form of one or multiple Python function calls, and additional code as needed, for when the model is fulfilling the user request.
+You may use a reference response annotated by a human rater for help. This reference response is of high quality. You can compare the agent's response with the reference response and decide if the agent's response is valid.
+Note that sometimes the reference response only contains the key entities of the correct answer, and you need to be flexible enough to allow the agent response to contain more information than the reference response, or to present the key entities in a different format or structure, or in a shorter or longer format.
+When the agent response is provided in the form of tables/dataframes, or would best be provided in the form of tables/dataframes: focus on the key entities and main components requested in the user query and check whether you can retrieve those from the agent response. Likewise, if you have the reference response, then find the key entities and main components in it and check whether you can retrieve those from the agent response. If the prompt does not specify any format instructions and the main items/components are included in the response, then tolerate differences in the formatting of those tables/dataframes.
+
+You should follow the constitutions below very carefully to rate the model response:
+- Allow flexibility of format even when the reference code only uses one of the possible formats, unless the API spec or user prompt has an explicit format requirement
+  - e.g. For state names, allow both the abbreviation and the full name unless the API spec has an explicit requirement, e.g. both 'tx' and 'Texas' should be allowed in the agent response even when the reference code only uses one of them.
+  - e.g. If a reference response lists outputs in a list format, the agent response is allowed to use sentence format and vice versa, unless the user prompt explicitly asks for a specific format.
+  - e.g. For numbers, allow flexibility of formatting, e.g. 1000000 vs 1,000,000.
+- The model shouldn't assume that it doesn't have access to the corresponding data or is incapable of answering the question if the reference response is able to find a legitimate answer.
+- If the model response contains the correct final answer, rate it as valid even when the model response contains more information than the reference response.
+- If the user prompt has CSV or other table-format data, don't read it yourself. Trust the reference response's final answer instead.
+- When the validation needs math or date calculations, do not use your own calculator. Trust the reference response's final answer instead.
+- Be mindful about the units of numbers. For example, if the reference response says 100 miles, but the model response says 100 km, it is invalid.
+- When the agent response or the reference response is provided in the form of tables/dataframes: focus on the key entities and main components requested in the user query and check whether you can retrieve those from the agent response and whether they match the reference response. If the user query does not specify any format instructions and the main items/components are included in the response, then tolerate differences in the formatting of those tables/dataframes.
+- When the answer is in numeric format, check whether there are any format requirements for the numeric format, rounding, precision, number of decimals, etc. specified in the user query and the prompt. If there are no such instructions, then tolerate different numerical formats.
+- When the answer is in numeric format and there are rounding or precision differences between the agent response and the reference response, and no further instructions are provided, evaluate whether the rounding strategy or precision in the agent response follows the standards for that entity. For instance, model accuracy scores must be reported with at least two decimal places (e.g., 0.798 → 0.80 is acceptable, but 0.7 is not).
+ +Below are the inputs: +{{ + "Function API spec": {function_api_spec}, + "User prompt": {prompt}, + "Agent response": {response}, + "Reference response": {golden_response}, +}} + +The answer should be a json alone which follows the json structure below: +{{ + "is_the_agent_response_valid": [valid or invalid], + "reasoning": +}} +Answer with assertiveness: +""" + + +def _extract_validity_label(response_text: str) -> Optional[str]: + label_match_is_response_valid = re.search( + r'"is_the_agent_response_valid":\s*\[*[\n\s]*"*([^"^\]^\s]*)"*[\n\s]*\]*\s*[,\n\}]', + response_text, + ) + label_match_is_response_invalid = re.search( + r'"is_the_agent_response_invalid":\s*\[*[\n\s]*"*([^"^\]^\s]*)"*[\n\s]*\]*\s*[,\n\}]', + response_text, + ) + if label_match_is_response_valid: + return label_match_is_response_valid.group(1) + elif label_match_is_response_invalid: + return label_match_is_response_invalid.group(1) + else: + return None + + +def _format_function_api_spec(functions: list[FunctionSpec]) -> str: + function_api_spec = [] + for function in functions: + function_spec = { + "Function name": function.name, + "Function description": function.description, + } + function_api_spec.append(function_spec) + return json.dumps(function_api_spec) + + +class FinalResponseMatchV2Evaluator(LlmAsJudge): + """V2 final response match evaluator which uses an LLM to judge responses.""" + + def __init__( + self, + eval_metric: EvalMetric, + ): + super().__init__(eval_metric) + self._auto_rater_prompt_template = FINAL_RESPONSE_MATCH_V2_PROMPT + + @override + def format_auto_rater_prompt( + self, actual_invocation: Invocation, expected_invocation: Invocation + ) -> str: + reference = get_text_from_content(expected_invocation.final_response) + response = get_text_from_content(actual_invocation.final_response) + user_prompt = get_text_from_content(expected_invocation.user_content) + function_api_spec = _format_function_api_spec( + actual_invocation.function_api_spec + ) + return self._auto_rater_prompt_template.format( + function_api_spec=function_api_spec, + prompt=user_prompt, + response=response, + golden_response=reference, + ) + + @override + def convert_auto_rater_response_to_score( + self, llm_response: LlmResponse + ) -> Optional[float]: + try: + response_text = get_text_from_content(llm_response.content).strip() + label = _extract_validity_label(response_text) + except json.JSONDecodeError: + logger.error("Failed to parse auto rater response: %s", llm_response) + return None + if label == "valid": + return 1.0 + elif label == "invalid": + return 0.0 + else: + return None + + @override + def aggregate_invocation_results( + self, per_invocation_results: list[PerInvocationResult] + ) -> EvaluationResult: + """Computes the fraction of invocation results that are valid.""" + num_valid = 0 + num_evaluated = 0 + for result in per_invocation_results: + if result.score is None or result.eval_status == EvalStatus.NOT_EVALUATED: + continue + num_evaluated += 1 + num_valid += result.score + overall_score = num_valid / num_evaluated + return EvaluationResult( + overall_score=overall_score, + overall_eval_status=get_eval_status( + overall_score, self._eval_metric.threshold + ), + per_invocation_results=per_invocation_results, + ) diff --git a/src/google/adk/evaluation/llm_as_judge.py b/src/google/adk/evaluation/llm_as_judge.py new file mode 100644 index 000000000..81a841073 --- /dev/null +++ b/src/google/adk/evaluation/llm_as_judge.py @@ -0,0 +1,136 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache 
License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from abc import abstractmethod
+from typing import Optional
+
+from google.genai import types as genai_types
+from typing_extensions import override
+
+from ..models.base_llm import BaseLlm
+from ..models.llm_request import LlmRequest
+from ..models.llm_response import LlmResponse
+from ..models.registry import LLMRegistry
+from .eval_case import Invocation
+from .eval_metrics import EvalMetric
+from .evaluator import EvalStatus
+from .evaluator import EvaluationResult
+from .evaluator import Evaluator
+from .evaluator import PerInvocationResult
+
+
+def get_text_from_content(content: Optional[genai_types.Content]) -> str:
+  if content and content.parts:
+    return "\n".join([p.text for p in content.parts if p.text])
+  # Return an empty string (rather than implicitly returning None) so the
+  # declared return type holds and callers can safely use str methods.
+  return ""
+
+
+def get_eval_status(score: Optional[float], threshold: float) -> EvalStatus:
+  if score is None:
+    return EvalStatus.NOT_EVALUATED
+  return EvalStatus.PASSED if score >= threshold else EvalStatus.FAILED
+
+
+class LlmAsJudge(Evaluator):
+  """Response evaluator based on an auto-rater (LLM).
+
+  It is meant to be extended by specific auto-raters for different evaluation
+  tasks:
+  - Provide the prompt template, and implement format_auto_rater_prompt to
+    format the auto-rater prompt for a given invocation.
+  - Implement convert_auto_rater_response_to_score to parse the auto-rater
+    response and return the corresponding score.
+  - Implement aggregate_invocation_results to aggregate the per-invocation
+    results to get the overall score.
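+
+  A minimal sketch of a concrete subclass (illustrative only; the class name
+  and prompt text below are not part of the API):
+
+    class SimpleMatchJudge(LlmAsJudge):
+
+      def format_auto_rater_prompt(self, actual, expected):
+        return (
+            "Reply with the single word valid or invalid. "
+            f"Response: {get_text_from_content(actual.final_response)} "
+            f"Reference: {get_text_from_content(expected.final_response)}"
+        )
+
+      def convert_auto_rater_response_to_score(self, auto_rater_response):
+        text = (get_text_from_content(auto_rater_response.content) or "").lower()
+        if "invalid" in text:
+          return 0.0
+        return 1.0 if "valid" in text else None
+
+      def aggregate_invocation_results(self, per_invocation_results):
+        scores = [r.score for r in per_invocation_results if r.score is not None]
+        overall = sum(scores) / len(scores) if scores else None
+        return EvaluationResult(
+            overall_score=overall,
+            overall_eval_status=get_eval_status(
+                overall, self._eval_metric.threshold
+            ),
+            per_invocation_results=per_invocation_results,
+        )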
+ """ + + def __init__( + self, + eval_metric: EvalMetric, + ): + self._eval_metric = eval_metric + if not eval_metric.judge_model_options: + raise ValueError("Judge model options is required for LlmAsJudge.") + self._judge_model_options = eval_metric.judge_model_options + if self._judge_model_options.judge_model_config is None: + self._judge_model_options.judge_model_config = ( + genai_types.GenerateContentConfig() + ) + self._judge_model = self._setup_auto_rater() + + @abstractmethod + def format_auto_rater_prompt( + self, actual: Invocation, expected: Invocation + ) -> str: + """Formats the auto-rater prompt to evaluate the given invocation.""" + + @abstractmethod + def convert_auto_rater_response_to_score( + self, auto_rater_response: LlmResponse + ) -> Optional[float]: + """Parses auto_rater_response and returns the corresponding score, or None if the score cannot be determined.""" + + @abstractmethod + def aggregate_invocation_results( + self, + per_invocation_results: list[PerInvocationResult], + ) -> EvaluationResult: + """Aggregates the per invocation results to get the overall score.""" + + @override + async def evaluate_invocations( + self, + actual_invocations: list[Invocation], + expected_invocations: list[Invocation], + ) -> EvaluationResult: + per_invocation_results = [] + for actual, expected in zip(actual_invocations, expected_invocations): + auto_rater_prompt = self.format_auto_rater_prompt(actual, expected) + llm_request = LlmRequest( + model=self._judge_model_options.judge_model, + contents=[ + genai_types.Content( + parts=[genai_types.Part(text=auto_rater_prompt)], + role="user", + ) + ], + config=self._judge_model_options.judge_model_config, + ) + for _ in range(self._judge_model_options.num_samples): + async for llm_response in self._judge_model.generate_content_async( + llm_request + ): + # Non-streaming call, so there is only one response content. + score = self.convert_auto_rater_response_to_score(llm_response) + per_invocation_results.append( + PerInvocationResult( + actual_invocation=actual, + expected_invocation=expected, + score=score, + eval_status=get_eval_status( + score, self._eval_metric.threshold + ), + ) + ) + + if per_invocation_results: + return self.aggregate_invocation_results(per_invocation_results) + return EvaluationResult() + + def _setup_auto_rater(self) -> BaseLlm: + model_id = self._judge_model_options.judge_model + llm_registry = LLMRegistry() + llm_class = llm_registry.resolve(model_id) + return llm_class(model=model_id) diff --git a/tests/unittests/evaluation/test_final_response_match_v2.py b/tests/unittests/evaluation/test_final_response_match_v2.py new file mode 100644 index 000000000..5b143caa6 --- /dev/null +++ b/tests/unittests/evaluation/test_final_response_match_v2.py @@ -0,0 +1,307 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from google.adk.evaluation.eval_case import FunctionSpec +from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.eval_metrics import EvalMetric +from google.adk.evaluation.eval_metrics import JudgeModelOptions +from google.adk.evaluation.evaluator import EvalStatus +from google.adk.evaluation.evaluator import PerInvocationResult +from google.adk.evaluation.final_response_match_v2 import _extract_validity_label +from google.adk.evaluation.final_response_match_v2 import FinalResponseMatchV2Evaluator +from google.adk.models.llm_response import LlmResponse +from google.genai import types as genai_types +import pytest + + +def test_extract_validity_label_missing_key(): + response_text = """```json + { + "is_the_agent_response_valid_or_invalid": "valid", + "reasoning": "The response is valid." + } + ```""" + label = _extract_validity_label(response_text) + assert label is None + + +@pytest.mark.parametrize( + "response_text", + [ + """```json + { + "is_the_agent_response_valid": "valid", + "reasoning": "The response is valid." + } + ```""", + """```json + { + "is_the_agent_response_valid": ["valid"], + "reasoning": "The response is valid." + } + ```""", + """```json + { + "is_the_agent_response_valid":\n [ "valid\n"], + "reasoning": "The response is valid." + } + ```""", + ], +) +def test_extract_validity_label(response_text): + label = _extract_validity_label(response_text) + assert label == "valid" + + +@pytest.mark.parametrize( + "response_text", + [ + """```json + { + "is_the_agent_response_invalid": "invalid", + "reasoning": "The response is invalid." + } + ```""", + """```json + { + "is_the_agent_response_invalid": ["invalid"], + "reasoning": "The response is invalid." + } + ```""", + """```json + { + "is_the_agent_response_invalid":\n [ "invalid\n"], + "reasoning": "The response is invalid." + } + ```""", + ], +) +def test_extract_validity_label_invalid(response_text): + label = _extract_validity_label(response_text) + assert label == "invalid" + + +def create_test_template() -> str: + return """ +This is a test template. 
+ +{{ + "Function API spec": {function_api_spec}, + "User prompt": {prompt}, + "Agent response": {response}, + "Reference response": {golden_response}, +}} + +The answer should be a json alone which follows the json structure below: +{{ + "is_the_agent_response_valid": [valid or invalid], + "reasoning": +}} +""" + + +def _create_test_evaluator_gemini( + threshold: float, +) -> FinalResponseMatchV2Evaluator: + evaluator = FinalResponseMatchV2Evaluator( + EvalMetric( + metric_name="final_response_match_v2", + threshold=threshold, + judge_model_options=JudgeModelOptions( + judge_model="gemini-2.5-flash", + num_samples=3, + ), + ), + ) + evaluator._auto_rater_prompt_template = create_test_template() + return evaluator + + +def _create_test_invocations( + candidate: str, reference: str +) -> tuple[Invocation, Invocation]: + """Returns tuple of (actual_invocation, expected_invocation).""" + actual_invocation = Invocation( + user_content=genai_types.Content( + parts=[genai_types.Part(text="This is a test query.")], + role="user", + ), + final_response=genai_types.Content( + parts=[genai_types.Part(text=candidate)], + role="model", + ), + function_api_spec=[ + FunctionSpec( + name="test_tool", + description="description.", + ), + ], + ) + expected_invocation = Invocation( + user_content=genai_types.Content( + parts=[genai_types.Part(text="This is a test query.")], + role="user", + ), + final_response=genai_types.Content( + parts=[genai_types.Part(text=reference)], + role="model", + ), + ) + return actual_invocation, expected_invocation + + +def test_format_auto_rater_prompt(): + evaluator = _create_test_evaluator_gemini(threshold=0.8) + actual_invocation, expected_invocation = _create_test_invocations( + "candidate text", "reference text" + ) + prompt = evaluator.format_auto_rater_prompt( + actual_invocation, expected_invocation + ) + assert prompt == """ +This is a test template. + +{ + "Function API spec": [{"Function name": "test_tool", "Function description": "description."}], + "User prompt": This is a test query., + "Agent response": candidate text, + "Reference response": reference text, +} + +The answer should be a json alone which follows the json structure below: +{ + "is_the_agent_response_valid": [valid or invalid], + "reasoning": +} +""" + + +def test_convert_auto_rater_response_to_score_valid(): + evaluator = _create_test_evaluator_gemini(threshold=0.8) + auto_rater_response = """```json +{ + "is_the_agent_response_valid": "valid", + "reasoning": "The response is valid." +} +```""" + llm_response = LlmResponse( + content=genai_types.Content( + parts=[genai_types.Part(text=auto_rater_response)], + role="model", + ) + ) + score = evaluator.convert_auto_rater_response_to_score(llm_response) + assert score == 1.0 + + +def test_convert_auto_rater_response_to_score_invalid(): + evaluator = _create_test_evaluator_gemini(threshold=0.8) + auto_rater_response = """```json +{ + "is_the_agent_response_valid": "invalid", + "reasoning": "The response is invalid." 
+} +```""" + llm_response = LlmResponse( + content=genai_types.Content( + parts=[genai_types.Part(text=auto_rater_response)], + role="model", + ) + ) + score = evaluator.convert_auto_rater_response_to_score(llm_response) + assert score == 0.0 + + +def test_convert_auto_rater_response_to_score_invalid_json(): + evaluator = _create_test_evaluator_gemini(threshold=0.8) + llm_response = LlmResponse( + content=genai_types.Content( + parts=[genai_types.Part(text="invalid json")], + role="model", + ) + ) + score = evaluator.convert_auto_rater_response_to_score(llm_response) + assert score is None + + +def test_convert_auto_rater_response_to_score_missing_key(): + evaluator = _create_test_evaluator_gemini(threshold=0.8) + llm_response = LlmResponse( + content=genai_types.Content( + parts=[genai_types.Part(text="{}")], + role="model", + ) + ) + score = evaluator.convert_auto_rater_response_to_score(llm_response) + assert score is None + + +def test_aggregate_invocation_results(): + evaluator = _create_test_evaluator_gemini(threshold=0.5) + actual_invocation, expected_invocation = _create_test_invocations( + "candidate text", "reference text" + ) + + per_invocation_results = [ + PerInvocationResult( + actual_invocation=actual_invocation, + expected_invocation=expected_invocation, + score=1.0, + eval_status=EvalStatus.PASSED, + ), + PerInvocationResult( + actual_invocation=actual_invocation, + expected_invocation=expected_invocation, + score=1.0, + eval_status=EvalStatus.PASSED, + ), + PerInvocationResult( + actual_invocation=actual_invocation, + expected_invocation=expected_invocation, + score=0.0, + eval_status=EvalStatus.FAILED, + ), + PerInvocationResult( + actual_invocation=actual_invocation, + expected_invocation=expected_invocation, + score=0.0, + eval_status=EvalStatus.FAILED, + ), + PerInvocationResult( + actual_invocation=actual_invocation, + expected_invocation=expected_invocation, + score=None, + eval_status=EvalStatus.PASSED, + ), + PerInvocationResult( + actual_invocation=actual_invocation, + expected_invocation=expected_invocation, + score=100.0, + eval_status=EvalStatus.NOT_EVALUATED, + ), + PerInvocationResult( + actual_invocation=actual_invocation, + expected_invocation=expected_invocation, + score=None, + eval_status=EvalStatus.NOT_EVALUATED, + ), + ] + + aggregated_result = evaluator.aggregate_invocation_results( + per_invocation_results + ) + assert aggregated_result.overall_score == 0.5 + assert aggregated_result.overall_eval_status == EvalStatus.PASSED diff --git a/tests/unittests/evaluation/test_llm_as_judge.py b/tests/unittests/evaluation/test_llm_as_judge.py new file mode 100644 index 000000000..b95385628 --- /dev/null +++ b/tests/unittests/evaluation/test_llm_as_judge.py @@ -0,0 +1,208 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import Optional +from unittest.mock import MagicMock + +from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.eval_metrics import EvalMetric +from google.adk.evaluation.eval_metrics import JudgeModelOptions +from google.adk.evaluation.evaluator import EvalStatus +from google.adk.evaluation.evaluator import EvaluationResult +from google.adk.evaluation.evaluator import PerInvocationResult +from google.adk.evaluation.llm_as_judge import get_eval_status +from google.adk.evaluation.llm_as_judge import get_text_from_content +from google.adk.evaluation.llm_as_judge import LlmAsJudge +from google.adk.models.llm_response import LlmResponse +from google.genai import types as genai_types +import pytest + + +class MockLlmAsJudge(LlmAsJudge): + + def format_auto_rater_prompt( + self, actual_invocation: Invocation, expected_invocation: Invocation + ) -> str: + return "formatted prompt" + + def convert_auto_rater_response_to_score( + self, llm_response: LlmResponse + ) -> Optional[float]: + return 1.0 + + def aggregate_invocation_results( + self, per_invocation_results: list[PerInvocationResult] + ) -> EvaluationResult: + return EvaluationResult( + overall_score=1.0, overall_eval_status=EvalStatus.PASSED + ) + + +@pytest.fixture +def mock_llm_as_judge(): + return MockLlmAsJudge( + EvalMetric( + metric_name="test_metric", + threshold=0.5, + judge_model_options=JudgeModelOptions( + judge_model="gemini-2.5-flash", + judge_model_config=genai_types.GenerateContentConfig(), + num_samples=3, + ), + ), + ) + + +def test_get_text_from_content(): + content = genai_types.Content( + parts=[ + genai_types.Part(text="This is a test text."), + genai_types.Part(text="This is another test text."), + ], + role="model", + ) + assert ( + get_text_from_content(content) + == "This is a test text.\nThis is another test text." 
+ ) + + +def test_get_eval_status(): + assert get_eval_status(score=0.8, threshold=0.8) == EvalStatus.PASSED + assert get_eval_status(score=0.7, threshold=0.8) == EvalStatus.FAILED + assert get_eval_status(score=0.8, threshold=0.9) == EvalStatus.FAILED + assert get_eval_status(score=0.9, threshold=0.8) == EvalStatus.PASSED + assert get_eval_status(score=None, threshold=0.8) == EvalStatus.NOT_EVALUATED + + +def test_llm_as_judge_init_missing_judge_model_options(): + with pytest.raises(ValueError): + MockLlmAsJudge( + EvalMetric(metric_name="test_metric", threshold=0.8), + ) + + +def test_llm_as_judge_init_unregistered_model(): + with pytest.raises(ValueError): + MockLlmAsJudge( + EvalMetric( + metric_name="test_metric", + threshold=0.8, + judge_model_options=JudgeModelOptions( + judge_model="unregistered_model", + ), + ), + ) + + +@pytest.fixture +def mock_judge_model(): + mock_judge_model = MagicMock() + + async def mock_generate_content_async(llm_request): + yield LlmResponse( + content=genai_types.Content( + parts=[genai_types.Part(text="auto rater response")], + ) + ) + + mock_judge_model.generate_content_async = mock_generate_content_async + return mock_judge_model + + +@pytest.mark.asyncio +async def test_evaluate_invocations_with_mock( + mock_llm_as_judge, mock_judge_model +): + mock_llm_as_judge._judge_model = mock_judge_model + + mock_format_auto_rater_prompt = MagicMock( + wraps=mock_llm_as_judge.format_auto_rater_prompt + ) + mock_llm_as_judge.format_auto_rater_prompt = mock_format_auto_rater_prompt + + mock_convert_auto_rater_response_to_score = MagicMock( + wraps=mock_llm_as_judge.convert_auto_rater_response_to_score + ) + mock_llm_as_judge.convert_auto_rater_response_to_score = ( + mock_convert_auto_rater_response_to_score + ) + + mock_aggregate_invocation_results = MagicMock( + wraps=mock_llm_as_judge.aggregate_invocation_results + ) + mock_llm_as_judge.aggregate_invocation_results = ( + mock_aggregate_invocation_results + ) + + actual_invocations = [ + Invocation( + invocation_id="id1", + user_content=genai_types.Content( + parts=[genai_types.Part(text="user content 1")], + role="user", + ), + final_response=genai_types.Content( + parts=[genai_types.Part(text="final response 1")], + role="model", + ), + ), + Invocation( + invocation_id="id2", + user_content=genai_types.Content( + parts=[genai_types.Part(text="user content 2")], + role="user", + ), + final_response=genai_types.Content( + parts=[genai_types.Part(text="final response 2")], + role="model", + ), + ), + ] + expected_invocations = [ + Invocation( + invocation_id="id1", + user_content=genai_types.Content( + parts=[genai_types.Part(text="user content 1")], + role="user", + ), + final_response=genai_types.Content( + parts=[genai_types.Part(text="expected response 1")], + role="model", + ), + ), + Invocation( + invocation_id="id2", + user_content=genai_types.Content( + parts=[genai_types.Part(text="user content 2")], + role="user", + ), + final_response=genai_types.Content( + parts=[genai_types.Part(text="expected response 2")], + role="model", + ), + ), + ] + + result = await mock_llm_as_judge.evaluate_invocations( + actual_invocations, expected_invocations + ) + + # Assertions + assert result.overall_score == 1.0 + assert mock_llm_as_judge.format_auto_rater_prompt.call_count == 2 + assert mock_llm_as_judge.convert_auto_rater_response_to_score.call_count == 6 + assert mock_llm_as_judge.aggregate_invocation_results.call_count == 1