Skip to content

Commit 34ee3fe

Browse files
brianhmjbrianhmjungarithmetic1728
authored
feat: adding domain-wide delegation flow in impersonated credential (#1624)
* Adding a flow in impersonated credentials to check if a subject is specificed for domain-wide delegation auth. * Adding a flow in impersonated credentials to check if a subject is specificed for domain-wide delegation auth. * Minor fixes to dwd flow in impersonation * Adding a flow in impersonated credentials to check if a subject is specificed for domain-wide delegation auth. * deleted repeated * delete repeated code * Fixing where source credentials authentication header info is, and target scopes. * Formatted code to uniform standard * Fixing lint and coverage failures from kokoro tests --------- Co-authored-by: Brian Jung <brianhmj@google.com> Co-authored-by: arithmetic1728 <58957152+arithmetic1728@users.noreply.github.com>
1 parent 1972c7b commit 34ee3fe

File tree

3 files changed

+225
-1
lines changed

3 files changed

+225
-1
lines changed

google/auth/iam.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@
4848
+ "/serviceAccounts/{}:signBlob"
4949
)
5050

51+
_IAM_SIGNJWT_ENDPOINT = (
52+
"https://iamcredentials.googleapis.com/v1/projects/-"
53+
+ "/serviceAccounts/{}:signJwt"
54+
)
55+
5156
_IAM_IDTOKEN_ENDPOINT = (
5257
"https://iamcredentials.googleapis.com/v1/"
5358
+ "projects/-/serviceAccounts/{}:generateIdToken"

google/auth/impersonated_credentials.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,15 @@
3838
from google.auth import iam
3939
from google.auth import jwt
4040
from google.auth import metrics
41+
from google.oauth2 import _client
4142

4243

4344
_REFRESH_ERROR = "Unable to acquire impersonated credentials"
4445

4546
_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds
4647

48+
_GOOGLE_OAUTH2_TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token"
49+
4750

4851
def _make_iam_token_request(
4952
request,
@@ -177,6 +180,7 @@ def __init__(
177180
target_principal,
178181
target_scopes,
179182
delegates=None,
183+
subject=None,
180184
lifetime=_DEFAULT_TOKEN_LIFETIME_SECS,
181185
quota_project_id=None,
182186
iam_endpoint_override=None,
@@ -204,9 +208,12 @@ def __init__(
204208
quota_project_id (Optional[str]): The project ID used for quota and billing.
205209
This project may be different from the project used to
206210
create the credentials.
207-
iam_endpoint_override (Optiona[str]): The full IAM endpoint override
211+
iam_endpoint_override (Optional[str]): The full IAM endpoint override
208212
with the target_principal embedded. This is useful when supporting
209213
impersonation with regional endpoints.
214+
subject (Optional[str]): sub field of a JWT. This field should only be set
215+
if you wish to impersonate as a user. This feature is useful when
216+
using domain wide delegation.
210217
"""
211218

212219
super(Credentials, self).__init__()
@@ -231,6 +238,7 @@ def __init__(
231238
self._target_principal = target_principal
232239
self._target_scopes = target_scopes
233240
self._delegates = delegates
241+
self._subject = subject
234242
self._lifetime = lifetime or _DEFAULT_TOKEN_LIFETIME_SECS
235243
self.token = None
236244
self.expiry = _helpers.utcnow()
@@ -275,6 +283,39 @@ def _update_token(self, request):
275283
# Apply the source credentials authentication info.
276284
self._source_credentials.apply(headers)
277285

286+
# If a subject is specified a domain-wide delegation auth-flow is initiated
287+
# to impersonate as the provided subject (user).
288+
if self._subject:
289+
if self.universe_domain != credentials.DEFAULT_UNIVERSE_DOMAIN:
290+
raise exceptions.GoogleAuthError(
291+
"Domain-wide delegation is not supported in universes other "
292+
+ "than googleapis.com"
293+
)
294+
295+
now = _helpers.utcnow()
296+
payload = {
297+
"iss": self._target_principal,
298+
"scope": _helpers.scopes_to_string(self._target_scopes or ()),
299+
"sub": self._subject,
300+
"aud": _GOOGLE_OAUTH2_TOKEN_ENDPOINT,
301+
"iat": _helpers.datetime_to_secs(now),
302+
"exp": _helpers.datetime_to_secs(now) + _DEFAULT_TOKEN_LIFETIME_SECS,
303+
}
304+
305+
assertion = _sign_jwt_request(
306+
request=request,
307+
principal=self._target_principal,
308+
headers=headers,
309+
payload=payload,
310+
delegates=self._delegates,
311+
)
312+
313+
self.token, self.expiry, _ = _client.jwt_grant(
314+
request, _GOOGLE_OAUTH2_TOKEN_ENDPOINT, assertion
315+
)
316+
317+
return
318+
278319
self.token, self.expiry = _make_iam_token_request(
279320
request=request,
280321
principal=self._target_principal,
@@ -478,3 +519,61 @@ def refresh(self, request):
478519
self.expiry = datetime.utcfromtimestamp(
479520
jwt.decode(id_token, verify=False)["exp"]
480521
)
522+
523+
524+
def _sign_jwt_request(request, principal, headers, payload, delegates=[]):
525+
"""Makes a request to the Google Cloud IAM service to sign a JWT using a
526+
service account's system-managed private key.
527+
Args:
528+
request (Request): The Request object to use.
529+
principal (str): The principal to request an access token for.
530+
headers (Mapping[str, str]): Map of headers to transmit.
531+
payload (Mapping[str, str]): The JWT payload to sign. Must be a
532+
serialized JSON object that contains a JWT Claims Set.
533+
delegates (Sequence[str]): The chained list of delegates required
534+
to grant the final access_token. If set, the sequence of
535+
identities must have "Service Account Token Creator" capability
536+
granted to the prceeding identity. For example, if set to
537+
[serviceAccountB, serviceAccountC], the source_credential
538+
must have the Token Creator role on serviceAccountB.
539+
serviceAccountB must have the Token Creator on
540+
serviceAccountC.
541+
Finally, C must have Token Creator on target_principal.
542+
If left unset, source_credential must have that role on
543+
target_principal.
544+
545+
Raises:
546+
google.auth.exceptions.TransportError: Raised if there is an underlying
547+
HTTP connection error
548+
google.auth.exceptions.RefreshError: Raised if the impersonated
549+
credentials are not available. Common reasons are
550+
`iamcredentials.googleapis.com` is not enabled or the
551+
`Service Account Token Creator` is not assigned
552+
"""
553+
iam_endpoint = iam._IAM_SIGNJWT_ENDPOINT.format(principal)
554+
555+
body = {"delegates": delegates, "payload": json.dumps(payload)}
556+
body = json.dumps(body).encode("utf-8")
557+
558+
response = request(url=iam_endpoint, method="POST", headers=headers, body=body)
559+
560+
# support both string and bytes type response.data
561+
response_body = (
562+
response.data.decode("utf-8")
563+
if hasattr(response.data, "decode")
564+
else response.data
565+
)
566+
567+
if response.status != http_client.OK:
568+
raise exceptions.RefreshError(_REFRESH_ERROR, response_body)
569+
570+
try:
571+
jwt_response = json.loads(response_body)
572+
signed_jwt = jwt_response["signedJwt"]
573+
return signed_jwt
574+
575+
except (KeyError, ValueError) as caught_exc:
576+
new_exc = exceptions.RefreshError(
577+
"{}: No signed JWT in response.".format(_REFRESH_ERROR), response_body
578+
)
579+
raise new_exc from caught_exc

tests/test_impersonated_credentials.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,17 @@ def mock_donor_credentials():
7171
yield grant
7272

7373

74+
@pytest.fixture
75+
def mock_dwd_credentials():
76+
with mock.patch("google.oauth2._client.jwt_grant", autospec=True) as grant:
77+
grant.return_value = (
78+
"1/fFAGRNJasdfz70BzhT3Zg",
79+
_helpers.utcnow() + datetime.timedelta(seconds=500),
80+
{},
81+
)
82+
yield grant
83+
84+
7485
class MockResponse:
7586
def __init__(self, json_data, status_code):
7687
self.json_data = json_data
@@ -123,6 +134,7 @@ def make_credentials(
123134
source_credentials=SOURCE_CREDENTIALS,
124135
lifetime=LIFETIME,
125136
target_principal=TARGET_PRINCIPAL,
137+
subject=None,
126138
iam_endpoint_override=None,
127139
):
128140

@@ -132,6 +144,7 @@ def make_credentials(
132144
target_scopes=self.TARGET_SCOPES,
133145
delegates=self.DELEGATES,
134146
lifetime=lifetime,
147+
subject=subject,
135148
iam_endpoint_override=iam_endpoint_override,
136149
)
137150

@@ -238,6 +251,28 @@ def test_refresh_success(self, use_data_bytes, mock_donor_credentials):
238251
== ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE
239252
)
240253

254+
@pytest.mark.parametrize("use_data_bytes", [True, False])
255+
def test_refresh_with_subject_success(self, use_data_bytes, mock_dwd_credentials):
256+
credentials = self.make_credentials(subject="test@email.com", lifetime=None)
257+
258+
response_body = {"signedJwt": "example_signed_jwt"}
259+
260+
request = self.make_request(
261+
data=json.dumps(response_body),
262+
status=http_client.OK,
263+
use_data_bytes=use_data_bytes,
264+
)
265+
266+
with mock.patch(
267+
"google.auth.metrics.token_request_access_token_impersonate",
268+
return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE,
269+
):
270+
credentials.refresh(request)
271+
272+
assert credentials.valid
273+
assert not credentials.expired
274+
assert credentials.token == "1/fFAGRNJasdfz70BzhT3Zg"
275+
241276
@pytest.mark.parametrize("use_data_bytes", [True, False])
242277
def test_refresh_success_nonGdu(self, use_data_bytes, mock_donor_credentials):
243278
source_credentials = service_account.Credentials(
@@ -418,6 +453,33 @@ def test_refresh_failure_http_error(self, mock_donor_credentials):
418453
assert not credentials.valid
419454
assert credentials.expired
420455

456+
def test_refresh_failure_subject_with_nondefault_domain(
457+
self, mock_donor_credentials
458+
):
459+
source_credentials = service_account.Credentials(
460+
SIGNER, "some@email.com", TOKEN_URI, universe_domain="foo.bar"
461+
)
462+
credentials = self.make_credentials(
463+
source_credentials=source_credentials, subject="test@email.com"
464+
)
465+
466+
expire_time = (_helpers.utcnow().replace(microsecond=0)).isoformat("T") + "Z"
467+
response_body = {"accessToken": "token", "expireTime": expire_time}
468+
request = self.make_request(
469+
data=json.dumps(response_body), status=http_client.OK
470+
)
471+
472+
with pytest.raises(exceptions.GoogleAuthError) as excinfo:
473+
credentials.refresh(request)
474+
475+
assert excinfo.match(
476+
"Domain-wide delegation is not supported in universes other "
477+
+ "than googleapis.com"
478+
)
479+
480+
assert not credentials.valid
481+
assert credentials.expired
482+
421483
def test_expired(self):
422484
credentials = self.make_credentials(lifetime=None)
423485
assert credentials.expired
@@ -810,3 +872,61 @@ def test_id_token_with_quota_project(
810872
id_creds.refresh(request)
811873

812874
assert id_creds.quota_project_id == "project-foo"
875+
876+
def test_sign_jwt_request_success(self):
877+
principal = "foo@example.com"
878+
expected_signed_jwt = "correct_signed_jwt"
879+
880+
response_body = {"keyId": "1", "signedJwt": expected_signed_jwt}
881+
request = self.make_request(
882+
data=json.dumps(response_body), status=http_client.OK
883+
)
884+
885+
signed_jwt = impersonated_credentials._sign_jwt_request(
886+
request=request, principal=principal, headers={}, payload={}
887+
)
888+
889+
assert signed_jwt == expected_signed_jwt
890+
request.assert_called_once_with(
891+
url="https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/foo@example.com:signJwt",
892+
method="POST",
893+
headers={},
894+
body=json.dumps({"delegates": [], "payload": json.dumps({})}).encode(
895+
"utf-8"
896+
),
897+
)
898+
899+
def test_sign_jwt_request_http_error(self):
900+
principal = "foo@example.com"
901+
902+
request = self.make_request(
903+
data="error_message", status=http_client.BAD_REQUEST
904+
)
905+
906+
with pytest.raises(exceptions.RefreshError) as excinfo:
907+
_ = impersonated_credentials._sign_jwt_request(
908+
request=request, principal=principal, headers={}, payload={}
909+
)
910+
911+
assert excinfo.match(impersonated_credentials._REFRESH_ERROR)
912+
913+
assert excinfo.value.args[0] == "Unable to acquire impersonated credentials"
914+
assert excinfo.value.args[1] == "error_message"
915+
916+
def test_sign_jwt_request_invalid_response_error(self):
917+
principal = "foo@example.com"
918+
919+
request = self.make_request(data="invalid_data", status=http_client.OK)
920+
921+
with pytest.raises(exceptions.RefreshError) as excinfo:
922+
_ = impersonated_credentials._sign_jwt_request(
923+
request=request, principal=principal, headers={}, payload={}
924+
)
925+
926+
assert excinfo.match(impersonated_credentials._REFRESH_ERROR)
927+
928+
assert (
929+
excinfo.value.args[0]
930+
== "Unable to acquire impersonated credentials: No signed JWT in response."
931+
)
932+
assert excinfo.value.args[1] == "invalid_data"

0 commit comments

Comments
 (0)