Skip to content

Commit f6f0e7c

Browse files
seanzhougooglecopybara-github
authored andcommitted
feat: add customized bigquer tool wrapper class to facilitate developer to handcraft bigquery api tool
PiperOrigin-RevId: 762626700
1 parent 5630973 commit f6f0e7c

File tree

8 files changed

+1036
-1
lines changed

8 files changed

+1036
-1
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""BigQuery Tools. (Experimental)
16+
17+
BigQuery Tools under this module are hand crafted and customized while the tools
18+
under google.adk.tools.google_api_tool are auto generated based on API
19+
definition. The rationales to have customized tool are:
20+
21+
1. BigQuery APIs have functions overlaps and LLM can't tell what tool to use
22+
2. BigQuery APIs have a lot of parameters with some rarely used, which are not
23+
LLM-friendly
24+
3. We want to provide more high-level tools like forecasting, RAG, segmentation,
25+
etc.
26+
4. We want to provide extra access guardrails in those tools. For example,
27+
execute_sql can't arbitrarily mutate existing data.
28+
"""
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List
16+
from typing import Optional
17+
18+
from fastapi.openapi.models import OAuth2
19+
from fastapi.openapi.models import OAuthFlowAuthorizationCode
20+
from fastapi.openapi.models import OAuthFlows
21+
from google.auth.exceptions import RefreshError
22+
from google.auth.transport.requests import Request
23+
from google.oauth2.credentials import Credentials
24+
from pydantic import BaseModel
25+
from pydantic import model_validator
26+
27+
from ...auth import AuthConfig
28+
from ...auth import AuthCredential
29+
from ...auth import AuthCredentialTypes
30+
from ...auth import OAuth2Auth
31+
from ..tool_context import ToolContext
32+
33+
BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache"
34+
35+
36+
class BigQueryCredentials(BaseModel):
37+
"""Configuration for Google API tools. (Experimental)"""
38+
39+
# Configure the model to allow arbitrary types like Credentials
40+
model_config = {"arbitrary_types_allowed": True}
41+
42+
credentials: Optional[Credentials] = None
43+
"""the existing oauth credentials to use. If set will override client ID,
44+
client secret, and scopes."""
45+
client_id: Optional[str] = None
46+
"""the oauth client ID to use."""
47+
client_secret: Optional[str] = None
48+
"""the oauth client secret to use."""
49+
scopes: Optional[List[str]] = None
50+
"""the scopes to use.
51+
"""
52+
53+
@model_validator(mode="after")
54+
def __post_init__(self) -> "BigQueryCredentials":
55+
"""Validate that either credentials or client ID/secret are provided."""
56+
if not self.credentials and (not self.client_id or not self.client_secret):
57+
raise ValueError(
58+
"Must provide either credentials or client_id abd client_secret pair."
59+
)
60+
if self.credentials:
61+
self.client_id = self.credentials.client_id
62+
self.client_secret = self.credentials.client_secret
63+
self.scopes = self.credentials.scopes
64+
return self
65+
66+
67+
class BigQueryCredentialsManager:
68+
"""Manages Google API credentials with automatic refresh and OAuth flow handling.
69+
70+
This class centralizes credential management so multiple tools can share
71+
the same authenticated session without duplicating OAuth logic.
72+
"""
73+
74+
def __init__(self, credentials: BigQueryCredentials):
75+
"""Initialize the credential manager.
76+
77+
Args:
78+
credential_config: Configuration containing OAuth details or existing
79+
credentials
80+
"""
81+
self.credentials = credentials
82+
83+
async def get_valid_credentials(
84+
self, tool_context: ToolContext
85+
) -> Optional[Credentials]:
86+
"""Get valid credentials, handling refresh and OAuth flow as needed.
87+
88+
Args:
89+
tool_context: The tool context for OAuth flow and state management
90+
required_scopes: Set of OAuth scopes required by the calling tool
91+
92+
Returns:
93+
Valid Credentials object, or None if OAuth flow is needed
94+
"""
95+
# First, try to get cached credentials from the instance
96+
creds = self.credentials.credentials
97+
98+
# If credentails are empty
99+
if not creds:
100+
creds = tool_context.get(BIGQUERY_TOKEN_CACHE_KEY, None)
101+
self.credentials.credentials = creds
102+
103+
# Check if we have valid credentials
104+
if creds and creds.valid:
105+
return creds
106+
107+
# Try to refresh expired credentials
108+
if creds and creds.expired and creds.refresh_token:
109+
try:
110+
creds.refresh(Request())
111+
if creds.valid:
112+
# Cache the refreshed credentials
113+
self.credentials.credentials = creds
114+
return creds
115+
except RefreshError:
116+
# Refresh failed, need to re-authenticate
117+
pass
118+
119+
# Need to perform OAuth flow
120+
return await self._perform_oauth_flow(tool_context)
121+
122+
async def _perform_oauth_flow(
123+
self, tool_context: ToolContext
124+
) -> Optional[Credentials]:
125+
"""Perform OAuth flow to get new credentials.
126+
127+
Args:
128+
tool_context: The tool context for OAuth flow
129+
required_scopes: Set of required OAuth scopes
130+
131+
Returns:
132+
New Credentials object, or None if flow is in progress
133+
"""
134+
135+
# Create OAuth configuration
136+
auth_scheme = OAuth2(
137+
flows=OAuthFlows(
138+
authorizationCode=OAuthFlowAuthorizationCode(
139+
authorizationUrl="https://accounts.google.com/o/oauth2/auth",
140+
tokenUrl="https://oauth2.googleapis.com/token",
141+
scopes={
142+
scope: f"Access to {scope}"
143+
for scope in self.credentials.scopes
144+
},
145+
)
146+
)
147+
)
148+
149+
auth_credential = AuthCredential(
150+
auth_type=AuthCredentialTypes.OAUTH2,
151+
oauth2=OAuth2Auth(
152+
client_id=self.credentials.client_id,
153+
client_secret=self.credentials.client_secret,
154+
),
155+
)
156+
157+
# Check if OAuth response is available
158+
auth_response = tool_context.get_auth_response(
159+
AuthConfig(auth_scheme=auth_scheme, raw_auth_credential=auth_credential)
160+
)
161+
162+
if auth_response:
163+
# OAuth flow completed, create credentials
164+
creds = Credentials(
165+
token=auth_response.oauth2.access_token,
166+
refresh_token=auth_response.oauth2.refresh_token,
167+
token_uri=auth_scheme.flows.authorizationCode.tokenUrl,
168+
client_id=self.credentials.client_id,
169+
client_secret=self.credentials.client_secret,
170+
scopes=list(self.credentials.scopes),
171+
)
172+
173+
# Cache the new credentials
174+
self.credentials.credentials = creds
175+
tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds
176+
return creds
177+
else:
178+
# Request OAuth flow
179+
tool_context.request_credential(
180+
AuthConfig(
181+
auth_scheme=auth_scheme,
182+
raw_auth_credential=auth_credential,
183+
)
184+
)
185+
return None
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import inspect
17+
from typing import Any
18+
from typing import Callable
19+
from typing import Optional
20+
from typing import override
21+
22+
from google.oauth2.credentials import Credentials
23+
24+
from ..function_tool import FunctionTool
25+
from ..tool_context import ToolContext
26+
from .bigquery_credentials import BigQueryCredentials
27+
from .bigquery_credentials import BigQueryCredentialsManager
28+
29+
30+
class BigQueryTool(FunctionTool):
31+
"""GoogleApiTool class for tools that call Google APIs.
32+
33+
This class is for developers to handcraft customized Google API tools rather
34+
than auto generate Google API tools based on API specs.
35+
36+
This class handles all the OAuth complexity, credential management,
37+
and common Google API patterns so subclasses can focus on their
38+
specific functionality.
39+
"""
40+
41+
def __init__(
42+
self,
43+
func: Callable[..., Any],
44+
credentials: Optional[BigQueryCredentials] = None,
45+
):
46+
"""Initialize the Google API tool.
47+
48+
Args:
49+
func: callable that impelments the tool's logic, can accept one
50+
'credential" parameter
51+
credentials: credentials used to call Google API. If None, then we don't
52+
hanlde the auth logic
53+
"""
54+
super().__init__(func=func)
55+
self._ignore_params.append("credentials")
56+
self.credentials_manager = (
57+
BigQueryCredentialsManager(credentials) if credentials else None
58+
)
59+
60+
@override
61+
async def run_async(
62+
self, *, args: dict[str, Any], tool_context: ToolContext
63+
) -> Any:
64+
"""Main entry point for tool execution with credential handling.
65+
66+
This method handles all the OAuth complexity and then delegates
67+
to the subclass's run_async_with_credential method.
68+
"""
69+
try:
70+
# Get valid credentials
71+
credentials = (
72+
await self.credentials_manager.get_valid_credentials(tool_context)
73+
if self.credentials_manager
74+
else None
75+
)
76+
77+
if credentials is None and self.credentials_manager:
78+
# OAuth flow in progress
79+
return (
80+
"User authorization is required to access Google services for"
81+
f" {self.name}. Please complete the authorization flow."
82+
)
83+
84+
# Execute the tool's specific logic with valid credentials
85+
86+
return await self._run_async_with_credential(
87+
credentials, args, tool_context
88+
)
89+
90+
except Exception as ex:
91+
return {
92+
"status": "ERROR",
93+
"error_details": str(ex),
94+
}
95+
96+
async def _run_async_with_credential(
97+
self,
98+
credentials: Credentials,
99+
args: dict[str, Any],
100+
tool_context: ToolContext,
101+
) -> Any:
102+
"""Execute the tool's specific logic with valid credentials.
103+
104+
Args:
105+
credentials: Valid Google OAuth credentials
106+
args: Arguments passed to the tool
107+
tool_context: Tool execution context
108+
109+
Returns:
110+
The result of the tool execution
111+
"""
112+
args_to_call = args.copy()
113+
signature = inspect.signature(self.func)
114+
if "credentials" in signature.parameters:
115+
args_to_call["credentials"] = credentials
116+
return await super().run_async(args=args_to_call, tool_context=tool_context)

src/google/adk/tools/function_tool.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(self, func: Callable[..., Any]):
5757

5858
super().__init__(name=name, description=doc)
5959
self.func = func
60+
self._ignore_params = ['tool_context', 'input_stream']
6061

6162
@override
6263
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
@@ -65,7 +66,7 @@ def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
6566
func=self.func,
6667
# The model doesn't understand the function context.
6768
# input_stream is for streaming tool
68-
ignore_params=['tool_context', 'input_stream'],
69+
ignore_params=self._ignore_params,
6970
variant=self._api_variant,
7071
)
7172
)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

0 commit comments

Comments
 (0)