feat: Propagate LLM agent tools to populate function API spec during agent evaluation #1681

Open · wants to merge 1 commit into main
2 changes: 2 additions & 0 deletions src/google/adk/evaluation/constants.py
@@ -18,3 +18,5 @@
"Eval module is not installed, please install via `pip install"
" google-adk[eval]`."
)

DEFAULT_JUDGE_MODEL = "gemini-2.5-flash"
19 changes: 16 additions & 3 deletions src/google/adk/evaluation/eval_case.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any
from typing import Optional
from typing import Tuple

from google.genai import types as genai_types
from pydantic import alias_generators
@@ -31,17 +31,27 @@ class EvalBaseModel(BaseModel):
)


class FunctionSpec(EvalBaseModel):
"""Function specification."""

name: str
"""The name of the function."""

description: str
"""The description of the function, including parameters and return type."""


class IntermediateData(EvalBaseModel):
"""Container for intermediate data that an agent would generate as it responds with a final answer."""

tool_uses: list[genai_types.FunctionCall] = []
"""Tool use trajectory in chronological order."""

intermediate_responses: list[Tuple[str, list[genai_types.Part]]] = []
intermediate_responses: list[tuple[str, list[genai_types.Part]]] = []
"""Intermediate responses generated by sub-agents to convey progress or status
in a multi-agent system, distinct from the final response.

This is expressed as a Tuple of:
This is expressed as a tuple of:
- Author: Usually the sub-agent name that generated the intermediate
response.

@@ -71,6 +81,9 @@ class Invocation(EvalBaseModel):
creation_timestamp: float = 0.0
"""Timestamp for the current invocation, primarily intended for debugging purposes."""

function_api_spec: list[FunctionSpec] = Field(default_factory=list)
"""Function API spec for the invocation."""


class SessionInput(EvalBaseModel):
"""Values that help initialize a Session."""
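
For reference, here is a minimal sketch (not part of the diff) of how the new FunctionSpec model and the Invocation.function_api_spec field fit together; the tool name, description, and message texts are made up for illustration, and the remaining Invocation fields are the ones already defined in eval_case.py:

from google.genai import types as genai_types

from google.adk.evaluation.eval_case import FunctionSpec, IntermediateData, Invocation

# Hypothetical tool metadata; in the PR this is collected from the agent's tools.
specs = [
    FunctionSpec(
        name="get_weather",
        description="Returns the current weather for a given city.",
    )
]

invocation = Invocation(
    user_content=genai_types.Content(
        role="user", parts=[genai_types.Part(text="What is the weather in Paris?")]
    ),
    final_response=genai_types.Content(
        role="model", parts=[genai_types.Part(text="It is sunny in Paris.")]
    ),
    intermediate_data=IntermediateData(tool_uses=[]),
    function_api_spec=specs,  # new field; defaults to an empty list
)
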
21 changes: 21 additions & 0 deletions src/google/adk/evaluation/eval_metrics.py
@@ -16,14 +16,30 @@

from typing import Optional

from google.genai import types as genai_types
from pydantic import alias_generators
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field

from .eval_case import Invocation
from .evaluator import EvalStatus


class JudgeModelOptions(BaseModel):
"""Options for an eval metric's judge model."""

judge_model: str
"""The judge model to use for evaluation. It can be a model name."""

judge_model_config: Optional[genai_types.GenerateContentConfig] = Field(
default=None, description="""The configuration for the judge model."""
)

num_samples: int = 1
"""The number of samples to generate for each invocation."""


class EvalMetric(BaseModel):
"""A metric used to evaluate a particular aspect of an eval case."""

@@ -38,6 +54,11 @@ class EvalMetric(BaseModel):
threshold: float
"""A threshold value. Each metric decides how to interpret this threshold."""

judge_model_options: Optional[JudgeModelOptions] = Field(
default=None,
description="""Options for the judge model, if applicable.""",
)


class EvalMetricResult(EvalMetric):
"""The actual computed score/value of a particular EvalMetric."""
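
As an illustration (not part of the diff), an EvalMetric wired up with the new JudgeModelOptions might look as follows; the metric_name value is hypothetical, while DEFAULT_JUDGE_MODEL comes from the constants.py change above:

from google.genai import types as genai_types

from google.adk.evaluation.constants import DEFAULT_JUDGE_MODEL
from google.adk.evaluation.eval_metrics import EvalMetric, JudgeModelOptions

metric = EvalMetric(
    metric_name="final_response_match_v2",  # hypothetical metric name
    threshold=0.8,
    judge_model_options=JudgeModelOptions(
        judge_model=DEFAULT_JUDGE_MODEL,  # "gemini-2.5-flash"
        judge_model_config=genai_types.GenerateContentConfig(temperature=0.0),
        num_samples=3,  # ask the judge for three samples per invocation
    ),
)
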
13 changes: 13 additions & 0 deletions src/google/adk/evaluation/evaluation_generator.py
@@ -29,6 +29,7 @@
from ..sessions.in_memory_session_service import InMemorySessionService
from ..sessions.session import Session
from .eval_case import EvalCase
from .eval_case import FunctionSpec
from .eval_case import IntermediateData
from .eval_case import Invocation
from .eval_case import SessionInput
@@ -153,6 +154,17 @@ async def _generate_inferences_from_root_agent(
user_id = initial_session.user_id if initial_session else "test_user_id"
session_id = session_id if session_id else str(uuid.uuid4())

function_specs = []
if hasattr(root_agent, "canonical_tools"):
tools = await root_agent.canonical_tools()
for tool in tools:
function_specs.append(
FunctionSpec(
name=tool.name,
description=tool.description,
)
)

_ = await session_service.create_session(
app_name=app_name,
user_id=user_id,
@@ -201,6 +213,7 @@ async def _generate_inferences_from_root_agent(
user_content=user_content,
final_response=final_response,
intermediate_data=IntermediateData(tool_uses=tool_uses),
function_api_spec=function_specs,
)
)

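
For context, a short sketch (not part of the diff) of where the propagated name and description come from: with ADK's FunctionTool wrapper they are normally derived from the wrapped function's name and docstring, so the collected FunctionSpec entries end up looking like this (treat the exact wrapper attributes as an assumption; any tool exposing .name and .description behaves the same way):

from google.adk.evaluation.eval_case import FunctionSpec
from google.adk.tools import FunctionTool


def get_weather(city: str) -> str:
  """Returns the current weather for the given city."""
  return f"It is sunny in {city}."


tool = FunctionTool(func=get_weather)

# Mirrors the loop added to _generate_inferences_from_root_agent above.
spec = FunctionSpec(name=tool.name, description=tool.description)
# spec.name == "get_weather"; spec.description is the docstring above.
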
173 changes: 173 additions & 0 deletions src/google/adk/evaluation/final_response_match_v2.py
@@ -0,0 +1,173 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import json
import logging
import re
from typing import Optional

from typing_extensions import override

from ..models.llm_response import LlmResponse
from .eval_case import FunctionSpec
from .eval_case import Invocation
from .eval_metrics import EvalMetric
from .evaluator import EvalStatus
from .evaluator import EvaluationResult
from .evaluator import PerInvocationResult
from .llm_as_judge import get_eval_status
from .llm_as_judge import get_text_from_content
from .llm_as_judge import LlmAsJudge

logger = logging.getLogger("google_adk." + __name__)

FINAL_RESPONSE_MATCH_V2_PROMPT = """
You are an expert rater for an AI agent. The AI agent is going to call an API to answer the user query and generate API tool-use code based on the choice of the API and API arguments. The ideal model response should be a function call that fulfills the user query, or a natural language response that hedges or asks the user for further clarification if a function call does not apply.
The primary focus of this rating task is to check correctness of the model responses.

The data consists of:
- A set of python function definitions available to the agent.
- A user query.
- A model generated response for the prompt. The responses can consist of:
- Natural language, when the model is asking for clarification, or tells the user it does not possess the requested functionality / option.
- Code, in the form of one or multiple python function calls, and additional code as needed, for when the model is fulfilling the user request.
You can use help from a reference response annotated by a human rater. This reference response is of high quality. You can compare the agent's response with the reference response and decide whether the agent's response is valid.
Note that sometimes the reference response only contains the key entities of the correct answer, so be flexible and allow the agent response to contain more information than the reference response, or to present the key entities in a different format, structure, or length.
When the agent response is provided in the form of tables/dataframes or should be best provided in the form of tables/dataframes: focus on the key entities and main components requested in the user query and check whether you can retrieve those from the agent response. Likewise, if you have the reference response, then find out the key entities and main components in them and check whether you can retrieve those from the agent response. If the prompt does not specify any format instructions and the main items/components are included in the response then tolerate the differences in the formatting of those tables/dataframes.

You should follow the constitutions below very carefully to rate the model response:
- Allow flexibility of format even when reference code only uses one of the possible format, unless API spec or user prompt has explicit format requirement
- e.g. For state name, allow both abbreviation and full name unless API spec has explicit requirement. e.g. both 'tx' and 'Texas' should be allowed in the agent response even when reference code only uses one of them.
- e.g. If a reference response presents its outputs in a list format, the agent response is allowed to use sentence format and vice versa, unless the user prompt explicitly asks for a specific format.
- e.g. For numbers, allow flexibility of formatting, e.g. 1000000 vs 1,000,000.
- The model shouldn't assume that it doesn't have access to the relevant data or is incapable of answering the question if the reference response is able to find a legitimate answer.
- If the model response contains the correct final answer, rate it as valid even when the model response contains more information than the reference response.
- If the user prompt has csv or other table format data, don't read it yourself. Trust the reference response final answer instead.
- When the validation needs maths, date calculations, do not use your own calculator. Trust the reference response final answer instead.
- Be mindful about unit of numbers. For example, if the reference response says 100 miles, but the model response says 100 km, it is invalid.
- When the agent response or the reference response is provided in the form of tables/dataframes: focus on the key entities and main components requested in the user query and check whether you can retrieve those from the agent response and whether those match the reference response. If the user query does not specify any format instructions and the main items/components are included in the response then tolerate the differences in the formatting of those tables/dataframes.
- When the answer is in numeric format, check whether there are any format requirements in the numeric format, rounding, precision, number of decimals, etc. specified in the user query and the prompt. If there are no such instructions, then tolerate different numerical formats.
- When the answer is in numeric format and there are rounding or precision differences between the agent response and the reference response, if no further instructions are provided evaluate if the rounding strategy or precision in the agent response follows the standards for that entity. For instance, model accuracy scores must be reported with at least two decimal places (e.g., 0.798 → 0.80 is acceptable, but 0.7 is not).

Below are the inputs:
{{
"Function API spec": {function_api_spec},
"User prompt": {prompt},
"Agent response": {response},
"Reference response": {golden_response},
}}

The answer should be a json alone which follows the json structure below:
{{
"is_the_agent_response_valid": [valid or invalid],
"reasoning":
}}
Answer with assertiveness:
"""


def _extract_validity_label(response_text: str) -> Optional[str]:
label_match_is_response_valid = re.search(
r'"is_the_agent_response_valid":\s*\[*[\n\s]*"*([^"^\]^\s]*)"*[\n\s]*\]*\s*[,\n\}]',
response_text,
)
label_match_is_response_invalid = re.search(
r'"is_the_agent_response_invalid":\s*\[*[\n\s]*"*([^"^\]^\s]*)"*[\n\s]*\]*\s*[,\n\}]',
response_text,
)
if label_match_is_response_valid:
return label_match_is_response_valid.group(1)
elif label_match_is_response_invalid:
return label_match_is_response_invalid.group(1)
else:
return None


def _format_function_api_spec(functions: list[FunctionSpec]) -> str:
function_api_spec = []
for function in functions:
function_spec = {
"Function name": function.name,
"Function description": function.description,
}
function_api_spec.append(function_spec)
return json.dumps(function_api_spec)


class FinalResponseMatchV2Evaluator(LlmAsJudge):
"""V2 final response match evaluator which uses an LLM to judge responses."""

def __init__(
self,
eval_metric: EvalMetric,
):
super().__init__(eval_metric)
self._auto_rater_prompt_template = FINAL_RESPONSE_MATCH_V2_PROMPT

@override
def format_auto_rater_prompt(
self, actual_invocation: Invocation, expected_invocation: Invocation
) -> str:
reference = get_text_from_content(expected_invocation.final_response)
response = get_text_from_content(actual_invocation.final_response)
user_prompt = get_text_from_content(expected_invocation.user_content)
function_api_spec = _format_function_api_spec(
actual_invocation.function_api_spec
)
return self._auto_rater_prompt_template.format(
function_api_spec=function_api_spec,
prompt=user_prompt,
response=response,
golden_response=reference,
)

@override
def convert_auto_rater_response_to_score(
self, llm_response: LlmResponse
) -> Optional[float]:
try:
response_text = get_text_from_content(llm_response.content).strip()
label = _extract_validity_label(response_text)
except json.JSONDecodeError:
logger.error("Failed to parse auto rater response: %s", llm_response)
return None
if label == "valid":
return 1.0
elif label == "invalid":
return 0.0
else:
return None

@override
def aggregate_invocation_results(
self, per_invocation_results: list[PerInvocationResult]
) -> EvaluationResult:
"""Computes the fraction of invocation results that are valid."""
num_valid = 0
num_evaluated = 0
for result in per_invocation_results:
if result.score is None or result.eval_status == EvalStatus.NOT_EVALUATED:
continue
num_evaluated += 1
num_valid += result.score
overall_score = num_valid / num_evaluated
return EvaluationResult(
overall_score=overall_score,
overall_eval_status=get_eval_status(
overall_score, self._eval_metric.threshold
),
per_invocation_results=per_invocation_results,
)
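
A brief usage sketch (not from the PR): the evaluator is constructed from an EvalMetric, and the auto-rater's reply is parsed with the _extract_validity_label regex rather than json.loads, so a plain JSON answer in the requested structure maps to a 1.0/0.0 score; the metric_name below is hypothetical:

from google.adk.evaluation.eval_metrics import EvalMetric, JudgeModelOptions
from google.adk.evaluation.final_response_match_v2 import (
    FinalResponseMatchV2Evaluator,
    _extract_validity_label,
)

evaluator = FinalResponseMatchV2Evaluator(
    EvalMetric(
        metric_name="final_response_match_v2",  # hypothetical name
        threshold=0.8,
        judge_model_options=JudgeModelOptions(judge_model="gemini-2.5-flash"),
    )
)

# The regex accepts the bare JSON format the prompt asks for:
sample = '{"is_the_agent_response_valid": ["valid"], "reasoning": "matches reference"}'
assert _extract_validity_label(sample) == "valid"  # converted to a score of 1.0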
Loading