Skip to content

Commit 4876edb

Browse files
committed
feat: add workload identity federation between GCP and AWS
1 parent 44264ab commit 4876edb

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

litellm/llms/vertex_ai/vertex_llm_base.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from litellm.litellm_core_utils.asyncify import asyncify
1313
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
1414
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
15+
from google.auth import aws, identity_pool
1516

1617
from .common_utils import _get_gemini_url, _get_vertex_url, all_gemini_url_modes
1718

@@ -41,7 +42,6 @@ def load_auth(
4142
self, credentials: Optional[VERTEX_CREDENTIALS_TYPES], project_id: Optional[str]
4243
) -> Tuple[Any, str]:
4344
import google.auth as google_auth
44-
from google.auth import identity_pool
4545
from google.auth.transport.requests import (
4646
Request, # type: ignore[import-untyped]
4747
)
@@ -80,7 +80,15 @@ def load_auth(
8080

8181
# Check if the JSON object contains Workload Identity Federation configuration
8282
if "type" in json_obj and json_obj["type"] == "external_account":
83-
creds = identity_pool.Credentials.from_info(json_obj)
83+
# If environment_id key contains "aws" value it corresponds to an AWS config file
84+
if (
85+
"credential_source" in json_obj
86+
and "environment_id" in json_obj["credential_source"]
87+
and "aws" in json_obj["credential_source"]["environment_id"]
88+
):
89+
creds = aws.Credentials.from_info(json_obj)
90+
else:
91+
creds = identity_pool.Credentials.from_info(json_obj)
8492
else:
8593
creds = (
8694
google.oauth2.service_account.Credentials.from_service_account_info(

tests/litellm/llms/vertex_ai/test_vertex_llm_base.py

+37
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,40 @@ async def test_gemini_credentials(self, is_async):
174174
)
175175
assert token == ""
176176
assert project == ""
177+
178+
179+
@patch("litellm.llms.vertex_ai.vertex_llm_base.aws.Credentials.from_info")
180+
@patch("litellm.llms.vertex_ai.vertex_llm_base.identity_pool.Credentials.from_info")
181+
def test_load_auth_wif(self, mock_identity_pool, mock_aws):
182+
vertex_base = VertexBase()
183+
input_project_id = "some_project_id"
184+
185+
# Test case 1: Using Workload Identity Federation for Microsoft Azure and
186+
# OIDC identity providers (default behavior)
187+
json_obj_1 = {
188+
"type": "external_account",
189+
}
190+
mock_creds_1 = MagicMock()
191+
mock_identity_pool.return_value = mock_creds_1
192+
193+
creds_1, project_id= vertex_base.load_auth(
194+
credentials=json_obj_1, project_id=input_project_id
195+
)
196+
mock_identity_pool.assert_called_once_with(json_obj_1)
197+
assert creds_1 == mock_creds_1
198+
assert project_id == input_project_id
199+
200+
# Test case 2: Using Workload Identity Federation for AWS
201+
json_obj_2 = {
202+
"type": "external_account",
203+
"credential_source": {"environment_id": "aws1"}
204+
}
205+
mock_creds_2 = MagicMock()
206+
mock_aws.return_value = mock_creds_2
207+
208+
creds_2, project_id= vertex_base.load_auth(
209+
credentials=json_obj_2, project_id=input_project_id
210+
)
211+
mock_aws.assert_called_once_with(json_obj_2)
212+
assert creds_2 == mock_creds_2
213+
assert project_id == input_project_id

0 commit comments

Comments
 (0)