@@ -174,3 +174,57 @@ async def test_gemini_credentials(self, is_async):
174
174
)
175
175
assert token == ""
176
176
assert project == ""
177
+
178
+ def test_load_auth_wif (self ):
179
+ vertex_base = VertexBase ()
180
+ input_project_id = "some_project_id"
181
+
182
+ # Test case 1: Using Workload Identity Federation for Microsoft Azure and
183
+ # OIDC identity providers (default behavior)
184
+ json_obj_1 = {
185
+ "type" : "external_account" ,
186
+ }
187
+ mock_auth_1 = MagicMock ()
188
+ mock_creds_1 = MagicMock ()
189
+ mock_request_1 = MagicMock ()
190
+ mock_creds_1 = mock_auth_1 .identity_pool .Credentials .from_info .return_value
191
+ with patch .dict (sys .modules , {"google.auth" : mock_auth_1 ,
192
+ "google.auth.transport.requests" : mock_request_1 }):
193
+
194
+ creds_1 , project_id = vertex_base .load_auth (
195
+ credentials = json_obj_1 , project_id = input_project_id
196
+ )
197
+
198
+ mock_auth_1 .identity_pool .Credentials .from_info .assert_called_once_with (
199
+ json_obj_1
200
+ )
201
+ mock_creds_1 .refresh .assert_called_once_with (
202
+ mock_request_1 .Request .return_value
203
+ )
204
+ assert creds_1 == mock_creds_1
205
+ assert project_id == input_project_id
206
+
207
+ # Test case 2: Using Workload Identity Federation for AWS
208
+ json_obj_2 = {
209
+ "type" : "external_account" ,
210
+ "credential_source" : {"environment_id" : "aws1" }
211
+ }
212
+ mock_auth_2 = MagicMock ()
213
+ mock_creds_2 = MagicMock ()
214
+ mock_request_2 = MagicMock ()
215
+ mock_creds_2 = mock_auth_2 .aws .Credentials .from_info .return_value
216
+ with patch .dict (sys .modules , {"google.auth" : mock_auth_2 ,
217
+ "google.auth.transport.requests" : mock_request_2 }):
218
+
219
+ creds_2 , project_id = vertex_base .load_auth (
220
+ credentials = json_obj_2 , project_id = input_project_id
221
+ )
222
+
223
+ mock_auth_2 .aws .Credentials .from_info .assert_called_once_with (
224
+ json_obj_2
225
+ )
226
+ mock_creds_2 .refresh .assert_called_once_with (
227
+ mock_request_2 .Request .return_value
228
+ )
229
+ assert creds_2 == mock_creds_2
230
+ assert project_id == input_project_id
0 commit comments