12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ from __future__ import annotations
16
+
15
17
from typing import List
16
18
from typing import Optional
17
19
33
35
BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache"
34
36
35
37
36
- class BigQueryCredentials (BaseModel ):
38
+ class BigQueryCredentialsConfig (BaseModel ):
37
39
"""Configuration for Google API tools. (Experimental)"""
38
40
39
41
# Configure the model to allow arbitrary types like Credentials
40
42
model_config = {"arbitrary_types_allowed" : True }
41
43
42
44
credentials : Optional [Credentials ] = None
43
- """the existing oauth credentials to use. If set will override client ID,
44
- client secret, and scopes."""
45
+ """the existing oauth credentials to use. If set,this credential will be used
46
+ for every end user, end users don't need to be involved in the oauthflow. This
47
+ field is mutually exclusive with client_id, client_secret and scopes.
48
+ Don't set this field unless you are sure this credential has the permission to
49
+ access every end user's data.
50
+
51
+ Example usage: when the agent is deployed in Google Cloud environment and
52
+ the service account (used as application default credentials) has access to
53
+ all the required BigQuery resource. Setting this credential to allow user to
54
+ access the BigQuery resource without end users going through oauth flow.
55
+
56
+ To get application default credential: `google.auth.default(...)`. See more
57
+ details in https://cloud.google.com/docs/authentication/application-default-credentials.
58
+
59
+ When the deployed environment cannot provide a pre-existing credential,
60
+ consider setting below client_id, client_secret and scope for end users to go
61
+ through oauth flow, so that agent can access the user data.
62
+ """
45
63
client_id : Optional [str ] = None
46
64
"""the oauth client ID to use."""
47
65
client_secret : Optional [str ] = None
@@ -51,12 +69,20 @@ class BigQueryCredentials(BaseModel):
51
69
"""
52
70
53
71
@model_validator (mode = "after" )
54
- def __post_init__ (self ) -> "BigQueryCredentials" :
72
+ def __post_init__ (self ) -> BigQueryCredentialsConfig :
55
73
"""Validate that either credentials or client ID/secret are provided."""
56
74
if not self .credentials and (not self .client_id or not self .client_secret ):
57
75
raise ValueError (
58
76
"Must provide either credentials or client_id abd client_secret pair."
59
77
)
78
+ if self .credentials and (
79
+ self .client_id or self .client_secret or self .scopes
80
+ ):
81
+ raise ValueError (
82
+ "Cannot provide both existing credentials and"
83
+ " client_id/client_secret/scopes."
84
+ )
85
+
60
86
if self .credentials :
61
87
self .client_id = self .credentials .client_id
62
88
self .client_secret = self .credentials .client_secret
@@ -71,14 +97,14 @@ class BigQueryCredentialsManager:
71
97
the same authenticated session without duplicating OAuth logic.
72
98
"""
73
99
74
- def __init__ (self , credentials : BigQueryCredentials ):
100
+ def __init__ (self , credentials_config : BigQueryCredentialsConfig ):
75
101
"""Initialize the credential manager.
76
102
77
103
Args:
78
- credential_config: Configuration containing OAuth details or existing
79
- credentials
104
+ credentials_config: Credentials containing client id and client secrete
105
+ or default credentials
80
106
"""
81
- self .credentials = credentials
107
+ self .credentials_config = credentials_config
82
108
83
109
async def get_valid_credentials (
84
110
self , tool_context : ToolContext
@@ -87,18 +113,23 @@ async def get_valid_credentials(
87
113
88
114
Args:
89
115
tool_context: The tool context for OAuth flow and state management
90
- required_scopes: Set of OAuth scopes required by the calling tool
91
116
92
117
Returns:
93
118
Valid Credentials object, or None if OAuth flow is needed
94
119
"""
95
- # First, try to get cached credentials from the instance
96
- creds = self .credentials .credentials
120
+ # First, try to get credentials from the tool context
121
+ creds_json = tool_context .state .get (BIGQUERY_TOKEN_CACHE_KEY , None )
122
+ creds = (
123
+ Credentials .from_authorized_user_info (
124
+ creds_json , self .credentials_config .scopes
125
+ )
126
+ if creds_json
127
+ else None
128
+ )
97
129
98
- # If credentails are empty
130
+ # If credentails are empty use the default credential
99
131
if not creds :
100
- creds = tool_context .get (BIGQUERY_TOKEN_CACHE_KEY , None )
101
- self .credentials .credentials = creds
132
+ creds = self .credentials_config .credentials
102
133
103
134
# Check if we have valid credentials
104
135
if creds and creds .valid :
@@ -110,7 +141,7 @@ async def get_valid_credentials(
110
141
creds .refresh (Request ())
111
142
if creds .valid :
112
143
# Cache the refreshed credentials
113
- self . credentials . credentials = creds
144
+ tool_context . state [ BIGQUERY_TOKEN_CACHE_KEY ] = creds . to_json ()
114
145
return creds
115
146
except RefreshError :
116
147
# Refresh failed, need to re-authenticate
@@ -140,7 +171,7 @@ async def _perform_oauth_flow(
140
171
tokenUrl = "https://oauth2.googleapis.com/token" ,
141
172
scopes = {
142
173
scope : f"Access to { scope } "
143
- for scope in self .credentials .scopes
174
+ for scope in self .credentials_config .scopes
144
175
},
145
176
)
146
177
)
@@ -149,8 +180,8 @@ async def _perform_oauth_flow(
149
180
auth_credential = AuthCredential (
150
181
auth_type = AuthCredentialTypes .OAUTH2 ,
151
182
oauth2 = OAuth2Auth (
152
- client_id = self .credentials .client_id ,
153
- client_secret = self .credentials .client_secret ,
183
+ client_id = self .credentials_config .client_id ,
184
+ client_secret = self .credentials_config .client_secret ,
154
185
),
155
186
)
156
187
@@ -165,14 +196,14 @@ async def _perform_oauth_flow(
165
196
token = auth_response .oauth2 .access_token ,
166
197
refresh_token = auth_response .oauth2 .refresh_token ,
167
198
token_uri = auth_scheme .flows .authorizationCode .tokenUrl ,
168
- client_id = self .credentials .client_id ,
169
- client_secret = self .credentials .client_secret ,
170
- scopes = list (self .credentials .scopes ),
199
+ client_id = self .credentials_config .client_id ,
200
+ client_secret = self .credentials_config .client_secret ,
201
+ scopes = list (self .credentials_config .scopes ),
171
202
)
172
203
173
204
# Cache the new credentials
174
- self . credentials . credentials = creds
175
- tool_context . state [ BIGQUERY_TOKEN_CACHE_KEY ] = creds
205
+ tool_context . state [ BIGQUERY_TOKEN_CACHE_KEY ] = creds . to_json ()
206
+
176
207
return creds
177
208
else :
178
209
# Request OAuth flow
0 commit comments