Skip to content

Commit e367bef

Browse files
committed
feat: add workload identity federation between GCP and AWS
1 parent 44264ab commit e367bef

File tree

2 files changed

+65
-4
lines changed

2 files changed

+65
-4
lines changed

litellm/llms/vertex_ai/vertex_llm_base.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,11 @@ def load_auth(
4141
self, credentials: Optional[VERTEX_CREDENTIALS_TYPES], project_id: Optional[str]
4242
) -> Tuple[Any, str]:
4343
import google.auth as google_auth
44-
from google.auth import identity_pool
4544
from google.auth.transport.requests import (
4645
Request, # type: ignore[import-untyped]
4746
)
4847

4948
if credentials is not None:
50-
import google.oauth2.service_account
51-
5249
if isinstance(credentials, str):
5350
verbose_logger.debug(
5451
"Vertex: Loading vertex credentials from %s", credentials
@@ -80,8 +77,18 @@ def load_auth(
8077

8178
# Check if the JSON object contains Workload Identity Federation configuration
8279
if "type" in json_obj and json_obj["type"] == "external_account":
83-
creds = identity_pool.Credentials.from_info(json_obj)
80+
# If environment_id key contains "aws" value it corresponds to an AWS config file
81+
if (
82+
"credential_source" in json_obj
83+
and "environment_id" in json_obj["credential_source"]
84+
and "aws" in json_obj["credential_source"]["environment_id"]
85+
):
86+
creds = google_auth.aws.Credentials.from_info(json_obj)
87+
else:
88+
creds = google_auth.identity_pool.Credentials.from_info(json_obj)
8489
else:
90+
import google.oauth2.service_account
91+
8592
creds = (
8693
google.oauth2.service_account.Credentials.from_service_account_info(
8794
json_obj,

tests/litellm/llms/vertex_ai/test_vertex_llm_base.py

+54
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,57 @@ async def test_gemini_credentials(self, is_async):
174174
)
175175
assert token == ""
176176
assert project == ""
177+
178+
def test_load_auth_wif(self):
179+
vertex_base = VertexBase()
180+
input_project_id = "some_project_id"
181+
182+
# Test case 1: Using Workload Identity Federation for Microsoft Azure and
183+
# OIDC identity providers (default behavior)
184+
json_obj_1 = {
185+
"type": "external_account",
186+
}
187+
mock_auth_1 = MagicMock()
188+
mock_creds_1 = MagicMock()
189+
mock_request_1 = MagicMock()
190+
mock_creds_1 = mock_auth_1.identity_pool.Credentials.from_info.return_value
191+
with patch.dict(sys.modules, {"google.auth": mock_auth_1,
192+
"google.auth.transport.requests": mock_request_1}):
193+
194+
creds_1, project_id = vertex_base.load_auth(
195+
credentials=json_obj_1, project_id=input_project_id
196+
)
197+
198+
mock_auth_1.identity_pool.Credentials.from_info.assert_called_once_with(
199+
json_obj_1
200+
)
201+
mock_creds_1.refresh.assert_called_once_with(
202+
mock_request_1.Request.return_value
203+
)
204+
assert creds_1 == mock_creds_1
205+
assert project_id == input_project_id
206+
207+
# Test case 2: Using Workload Identity Federation for AWS
208+
json_obj_2 = {
209+
"type": "external_account",
210+
"credential_source": {"environment_id": "aws1"}
211+
}
212+
mock_auth_2 = MagicMock()
213+
mock_creds_2 = MagicMock()
214+
mock_request_2 = MagicMock()
215+
mock_creds_2 = mock_auth_2.aws.Credentials.from_info.return_value
216+
with patch.dict(sys.modules, {"google.auth": mock_auth_2,
217+
"google.auth.transport.requests": mock_request_2}):
218+
219+
creds_2, project_id = vertex_base.load_auth(
220+
credentials=json_obj_2, project_id=input_project_id
221+
)
222+
223+
mock_auth_2.aws.Credentials.from_info.assert_called_once_with(
224+
json_obj_2
225+
)
226+
mock_creds_2.refresh.assert_called_once_with(
227+
mock_request_2.Request.return_value
228+
)
229+
assert creds_2 == mock_creds_2
230+
assert project_id == input_project_id

0 commit comments

Comments
 (0)