Skip to content

Commit 6c999ca

Browse files
google-genai-bot authored and copybara-github committed
feat: Introduce write protected mode to BigQuery tools
This allows protecting against any write operations (e.g. updating or deleting a table), which is useful for agents that must only be used in a read-only mode, even while the user may have write permissions.

PiperOrigin-RevId: 769803741
1 parent 77f44a4 commit 6c999ca

File tree

10 files changed

+490
-51
lines changed

10 files changed

+490
-51
lines changed

contributing/samples/bigquery/agent.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@
1717
from google.adk.agents import llm_agent
1818
from google.adk.tools.bigquery import BigQueryCredentialsConfig
1919
from google.adk.tools.bigquery import BigQueryToolset
20+
from google.adk.tools.bigquery.config import BigQueryToolConfig
21+
from google.adk.tools.bigquery.config import WriteMode
2022
import google.auth
2123

2224
RUN_WITH_ADC = False
2325

2426

27+
tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED)
28+
2529
if RUN_WITH_ADC:
2630
# Initialize the tools to use the application default credentials.
2731
application_default_credentials, _ = google.auth.default()
@@ -37,7 +41,9 @@
3741
client_secret=os.getenv("OAUTH_CLIENT_SECRET"),
3842
)
3943

40-
bigquery_toolset = BigQueryToolset(credentials_config=credentials_config)
44+
bigquery_toolset = BigQueryToolset(
45+
credentials_config=credentials_config, bigquery_tool_config=tool_config
46+
)
4147

4248
# The variable name `root_agent` determines what your root agent is for the
4349
# debug CLI

src/google/adk/tools/bigquery/bigquery_tool.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ..tool_context import ToolContext
2626
from .bigquery_credentials import BigQueryCredentialsConfig
2727
from .bigquery_credentials import BigQueryCredentialsManager
28+
from .config import BigQueryToolConfig
2829

2930

3031
class BigQueryTool(FunctionTool):
@@ -41,21 +42,27 @@ class BigQueryTool(FunctionTool):
4142
def __init__(
4243
self,
4344
func: Callable[..., Any],
44-
credentials: Optional[BigQueryCredentialsConfig] = None,
45+
*,
46+
credentials_config: Optional[BigQueryCredentialsConfig] = None,
47+
bigquery_tool_config: Optional[BigQueryToolConfig] = None,
4548
):
4649
"""Initialize the Google API tool.
4750
4851
Args:
4952
func: callable that implements the tool's logic, can accept one
5053
'credentials' parameter
51-
credentials: credentials used to call Google API. If None, then we don't
52-
handle the auth logic
54+
credentials_config: credentials config used to call Google API. If None,
55+
then we don't handle the auth logic
5356
"""
5457
super().__init__(func=func)
5558
self._ignore_params.append("credentials")
56-
self.credentials_manager = (
57-
BigQueryCredentialsManager(credentials) if credentials else None
59+
self._ignore_params.append("config")
60+
self._credentials_manager = (
61+
BigQueryCredentialsManager(credentials_config)
62+
if credentials_config
63+
else None
5864
)
65+
self._tool_config = bigquery_tool_config
5966

6067
@override
6168
async def run_async(
@@ -69,12 +76,12 @@ async def run_async(
6976
try:
7077
# Get valid credentials
7178
credentials = (
72-
await self.credentials_manager.get_valid_credentials(tool_context)
73-
if self.credentials_manager
79+
await self._credentials_manager.get_valid_credentials(tool_context)
80+
if self._credentials_manager
7481
else None
7582
)
7683

77-
if credentials is None and self.credentials_manager:
84+
if credentials is None and self._credentials_manager:
7885
# OAuth flow in progress
7986
return (
8087
"User authorization is required to access Google services for"
@@ -84,7 +91,7 @@ async def run_async(
8491
# Execute the tool's specific logic with valid credentials
8592

8693
return await self._run_async_with_credential(
87-
credentials, args, tool_context
94+
credentials, self._tool_config, args, tool_context
8895
)
8996

9097
except Exception as ex:
@@ -96,6 +103,7 @@ async def run_async(
96103
async def _run_async_with_credential(
97104
self,
98105
credentials: Credentials,
106+
tool_config: BigQueryToolConfig,
99107
args: dict[str, Any],
100108
tool_context: ToolContext,
101109
) -> Any:
@@ -113,4 +121,6 @@ async def _run_async_with_credential(
113121
signature = inspect.signature(self.func)
114122
if "credentials" in signature.parameters:
115123
args_to_call["credentials"] = credentials
124+
if "config" in signature.parameters:
125+
args_to_call["config"] = tool_config
116126
return await super().run_async(args=args_to_call, tool_context=tool_context)

src/google/adk/tools/bigquery/bigquery_toolset.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from ...tools.base_toolset import ToolPredicate
2929
from .bigquery_credentials import BigQueryCredentialsConfig
3030
from .bigquery_tool import BigQueryTool
31+
from .config import BigQueryToolConfig
3132

3233

3334
class BigQueryToolset(BaseToolset):
@@ -38,9 +39,11 @@ def __init__(
3839
*,
3940
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
4041
credentials_config: Optional[BigQueryCredentialsConfig] = None,
42+
bigquery_tool_config: Optional[BigQueryToolConfig] = None,
4143
):
42-
self._credentials_config = credentials_config
4344
self.tool_filter = tool_filter
45+
self._credentials_config = credentials_config
46+
self._tool_config = bigquery_tool_config
4447

4548
def _is_tool_selected(
4649
self, tool: BaseTool, readonly_context: ReadonlyContext
@@ -64,14 +67,15 @@ async def get_tools(
6467
all_tools = [
6568
BigQueryTool(
6669
func=func,
67-
credentials=self._credentials_config,
70+
credentials_config=self._credentials_config,
71+
bigquery_tool_config=self._tool_config,
6872
)
6973
for func in [
7074
metadata_tool.get_dataset_info,
7175
metadata_tool.get_table_info,
7276
metadata_tool.list_dataset_ids,
7377
metadata_tool.list_table_ids,
74-
query_tool.execute_sql,
78+
query_tool.get_execute_sql(self._tool_config),
7579
]
7680
]
7781

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from enum import Enum
18+
19+
from pydantic import BaseModel
20+
21+
from ...utils.feature_decorator import experimental
22+
23+
24+
class WriteMode(Enum):
25+
"""Write mode indicating what levels of write operations are allowed in BigQuery."""
26+
27+
BLOCKED = 'blocked'
28+
"""No write operations are allowed.
29+
30+
This mode implies that only read (i.e. SELECT query) operations are allowed.
31+
"""
32+
33+
ALLOWED = 'allowed'
34+
"""All write operations are allowed."""
35+
36+
37+
@experimental('Config defaults may have breaking change in the future.')
38+
class BigQueryToolConfig(BaseModel):
39+
"""Configuration for BigQuery tools."""
40+
41+
write_mode: WriteMode = WriteMode.BLOCKED
42+
"""Write mode for BigQuery tools.
43+
44+
By default, the tool will allow only read operations. This behaviour may
45+
change in future versions.
46+
"""

src/google/adk/tools/bigquery/metadata_tool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from google.cloud import bigquery
1616
from google.oauth2.credentials import Credentials
1717

18-
from ...tools.bigquery import client
18+
from . import client
1919

2020

2121
def list_dataset_ids(project_id: str, credentials: Credentials) -> list[str]:

src/google/adk/tools/bigquery/query_tool.py

Lines changed: 138 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,26 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import functools
16+
import types
17+
from typing import Callable
18+
19+
from google.cloud import bigquery
1520
from google.oauth2.credentials import Credentials
1621

17-
from ...tools.bigquery import client
22+
from . import client
23+
from .config import BigQueryToolConfig
24+
from .config import WriteMode
1825

1926
MAX_DOWNLOADED_QUERY_RESULT_ROWS = 50
2027

2128

22-
def execute_sql(project_id: str, query: str, credentials: Credentials) -> dict:
29+
def execute_sql(
30+
project_id: str,
31+
query: str,
32+
credentials: Credentials,
33+
config: BigQueryToolConfig,
34+
) -> dict:
2335
"""Run a BigQuery SQL query in the project and return the result.
2436
2537
Args:
@@ -35,34 +47,49 @@ def execute_sql(project_id: str, query: str, credentials: Credentials) -> dict:
3547
query not returned in the result.
3648
3749
Examples:
38-
>>> execute_sql("bigframes-dev",
39-
... "SELECT island, COUNT(*) AS population "
40-
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
41-
{
42-
"rows": [
43-
{
44-
"island": "Dream",
45-
"population": 124
46-
},
47-
{
48-
"island": "Biscoe",
49-
"population": 168
50-
},
51-
{
52-
"island": "Torgersen",
53-
"population": 52
54-
}
55-
]
56-
}
50+
Fetch data or insights from a table:
51+
52+
>>> execute_sql("bigframes-dev",
53+
... "SELECT island, COUNT(*) AS population "
54+
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
55+
{
56+
"status": "SUCCESS",
57+
"rows": [
58+
{
59+
"island": "Dream",
60+
"population": 124
61+
},
62+
{
63+
"island": "Biscoe",
64+
"population": 168
65+
},
66+
{
67+
"island": "Torgersen",
68+
"population": 52
69+
}
70+
]
71+
}
5772
"""
5873

5974
try:
6075
bq_client = client.get_bigquery_client(credentials=credentials)
76+
if not config or config.write_mode == WriteMode.BLOCKED:
77+
query_job = bq_client.query(
78+
query,
79+
project=project_id,
80+
job_config=bigquery.QueryJobConfig(dry_run=True),
81+
)
82+
if query_job.statement_type != "SELECT":
83+
return {
84+
"status": "ERROR",
85+
"error_details": "Read-only mode only supports SELECT statements.",
86+
}
87+
6188
row_iterator = bq_client.query_and_wait(
6289
query, project=project_id, max_results=MAX_DOWNLOADED_QUERY_RESULT_ROWS
6390
)
6491
rows = [{key: val for key, val in row.items()} for row in row_iterator]
65-
result = {"rows": rows}
92+
result = {"status": "SUCCESS", "rows": rows}
6693
if (
6794
MAX_DOWNLOADED_QUERY_RESULT_ROWS is not None
6895
and len(rows) == MAX_DOWNLOADED_QUERY_RESULT_ROWS
@@ -74,3 +101,92 @@ def execute_sql(project_id: str, query: str, credentials: Credentials) -> dict:
74101
"status": "ERROR",
75102
"error_details": str(ex),
76103
}
104+
105+
106+
_execute_sql_write_examples = """
107+
Create a table from the result of a query:
108+
109+
>>> execute_sql("bigframes-dev",
110+
... "CREATE TABLE my_project.my_dataset.my_table AS "
111+
... "SELECT island, COUNT(*) AS population "
112+
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
113+
{
114+
"status": "SUCCESS",
115+
"rows": []
116+
}
117+
118+
Delete a table:
119+
120+
>>> execute_sql("bigframes-dev",
121+
... "DROP TABLE my_project.my_dataset.my_table")
122+
{
123+
"status": "SUCCESS",
124+
"rows": []
125+
}
126+
127+
Copy a table to another table:
128+
129+
>>> execute_sql("bigframes-dev",
130+
... "CREATE TABLE my_project.my_dataset.my_table_clone "
131+
... "CLONE my_project.my_dataset.my_table")
132+
{
133+
"status": "SUCCESS",
134+
"rows": []
135+
}
136+
137+
Create a snapshot (a lightweight, read-optimized copy) of an existing
138+
table:
139+
140+
>>> execute_sql("bigframes-dev",
141+
... "CREATE SNAPSHOT TABLE my_project.my_dataset.my_table_snapshot "
142+
... "CLONE my_project.my_dataset.my_table")
143+
{
144+
"status": "SUCCESS",
145+
"rows": []
146+
}
147+
148+
Notes:
149+
- If a destination table already exists, there are a few ways to overwrite
150+
it:
151+
- Use "CREATE OR REPLACE TABLE" instead of "CREATE TABLE".
152+
- First run "DROP TABLE", followed by "CREATE TABLE".
153+
- To insert data into a table, use "INSERT INTO" statement.
154+
"""
155+
156+
157+
def get_execute_sql(config: BigQueryToolConfig) -> Callable[..., dict]:
158+
"""Get the execute_sql tool customized as per the given tool config.
159+
160+
Args:
161+
config: BigQuery tool configuration indicating the behavior of the
162+
execute_sql tool.
163+
164+
Returns:
165+
callable[..., dict]: A version of the execute_sql tool respecting the tool
166+
config.
167+
"""
168+
169+
if not config or config.write_mode == WriteMode.BLOCKED:
170+
return execute_sql
171+
172+
# Create a new function object using the original function's code and globals.
173+
# We pass the original code, globals, name, defaults, and closure.
174+
# This creates a raw function object without copying other metadata yet.
175+
execute_sql_wrapper = types.FunctionType(
176+
execute_sql.__code__,
177+
execute_sql.__globals__,
178+
execute_sql.__name__,
179+
execute_sql.__defaults__,
180+
execute_sql.__closure__,
181+
)
182+
183+
# Use functools.update_wrapper to copy over other essential attributes
184+
# from the original function to the new one.
185+
# This includes __name__, __qualname__, __module__, __annotations__, etc.
186+
# It specifically allows us to then set __doc__ separately.
187+
functools.update_wrapper(execute_sql_wrapper, execute_sql)
188+
189+
# Now, set the new docstring
190+
execute_sql_wrapper.__doc__ += _execute_sql_write_examples
191+
192+
return execute_sql_wrapper

0 commit comments

Comments
 (0)