@@ -174,3 +174,40 @@ async def test_gemini_credentials(self, is_async):
174
174
)
175
175
assert token == ""
176
176
assert project == ""
177
+
178
+
179
+ @patch ("litellm.llms.vertex_ai.vertex_llm_base.aws.Credentials.from_info" )
180
+ @patch ("litellm.llms.vertex_ai.vertex_llm_base.identity_pool.Credentials.from_info" )
181
+ def test_load_auth_wif (self , mock_identity_pool , mock_aws ):
182
+ vertex_base = VertexBase ()
183
+ input_project_id = "some_project_id"
184
+
185
+ # Test case 1: Using Workload Identity Federation for Microsoft Azure and
186
+ # OIDC identity providers (default behavior)
187
+ json_obj_1 = {
188
+ "type" : "external_account" ,
189
+ }
190
+ mock_creds_1 = MagicMock ()
191
+ mock_identity_pool .return_value = mock_creds_1
192
+
193
+ creds_1 , project_id = vertex_base .load_auth (
194
+ credentials = json_obj_1 , project_id = input_project_id
195
+ )
196
+ mock_identity_pool .assert_called_once_with (json_obj_1 )
197
+ assert creds_1 == mock_creds_1
198
+ assert project_id == input_project_id
199
+
200
+ # Test case 2: Using Workload Identity Federation for AWS
201
+ json_obj_2 = {
202
+ "type" : "external_account" ,
203
+ "credential_source" : {"environment_id" : "aws1" }
204
+ }
205
+ mock_creds_2 = MagicMock ()
206
+ mock_aws .return_value = mock_creds_2
207
+
208
+ creds_2 , project_id = vertex_base .load_auth (
209
+ credentials = json_obj_2 , project_id = input_project_id
210
+ )
211
+ mock_aws .assert_called_once_with (json_obj_2 )
212
+ assert creds_2 == mock_creds_2
213
+ assert project_id == input_project_id
0 commit comments