@@ -71,6 +71,17 @@ def mock_donor_credentials():
71
71
yield grant
72
72
73
73
74
+ @pytest .fixture
75
+ def mock_dwd_credentials ():
76
+ with mock .patch ("google.oauth2._client.jwt_grant" , autospec = True ) as grant :
77
+ grant .return_value = (
78
+ "1/fFAGRNJasdfz70BzhT3Zg" ,
79
+ _helpers .utcnow () + datetime .timedelta (seconds = 500 ),
80
+ {},
81
+ )
82
+ yield grant
83
+
84
+
74
85
class MockResponse :
75
86
def __init__ (self , json_data , status_code ):
76
87
self .json_data = json_data
@@ -123,6 +134,7 @@ def make_credentials(
123
134
source_credentials = SOURCE_CREDENTIALS ,
124
135
lifetime = LIFETIME ,
125
136
target_principal = TARGET_PRINCIPAL ,
137
+ subject = None ,
126
138
iam_endpoint_override = None ,
127
139
):
128
140
@@ -132,6 +144,7 @@ def make_credentials(
132
144
target_scopes = self .TARGET_SCOPES ,
133
145
delegates = self .DELEGATES ,
134
146
lifetime = lifetime ,
147
+ subject = subject ,
135
148
iam_endpoint_override = iam_endpoint_override ,
136
149
)
137
150
@@ -238,6 +251,28 @@ def test_refresh_success(self, use_data_bytes, mock_donor_credentials):
238
251
== ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE
239
252
)
240
253
254
+ @pytest .mark .parametrize ("use_data_bytes" , [True , False ])
255
+ def test_refresh_with_subject_success (self , use_data_bytes , mock_dwd_credentials ):
256
+ credentials = self .make_credentials (subject = "test@email.com" , lifetime = None )
257
+
258
+ response_body = {"signedJwt" : "example_signed_jwt" }
259
+
260
+ request = self .make_request (
261
+ data = json .dumps (response_body ),
262
+ status = http_client .OK ,
263
+ use_data_bytes = use_data_bytes ,
264
+ )
265
+
266
+ with mock .patch (
267
+ "google.auth.metrics.token_request_access_token_impersonate" ,
268
+ return_value = ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE ,
269
+ ):
270
+ credentials .refresh (request )
271
+
272
+ assert credentials .valid
273
+ assert not credentials .expired
274
+ assert credentials .token == "1/fFAGRNJasdfz70BzhT3Zg"
275
+
241
276
@pytest .mark .parametrize ("use_data_bytes" , [True , False ])
242
277
def test_refresh_success_nonGdu (self , use_data_bytes , mock_donor_credentials ):
243
278
source_credentials = service_account .Credentials (
@@ -418,6 +453,33 @@ def test_refresh_failure_http_error(self, mock_donor_credentials):
418
453
assert not credentials .valid
419
454
assert credentials .expired
420
455
456
+ def test_refresh_failure_subject_with_nondefault_domain (
457
+ self , mock_donor_credentials
458
+ ):
459
+ source_credentials = service_account .Credentials (
460
+ SIGNER , "some@email.com" , TOKEN_URI , universe_domain = "foo.bar"
461
+ )
462
+ credentials = self .make_credentials (
463
+ source_credentials = source_credentials , subject = "test@email.com"
464
+ )
465
+
466
+ expire_time = (_helpers .utcnow ().replace (microsecond = 0 )).isoformat ("T" ) + "Z"
467
+ response_body = {"accessToken" : "token" , "expireTime" : expire_time }
468
+ request = self .make_request (
469
+ data = json .dumps (response_body ), status = http_client .OK
470
+ )
471
+
472
+ with pytest .raises (exceptions .GoogleAuthError ) as excinfo :
473
+ credentials .refresh (request )
474
+
475
+ assert excinfo .match (
476
+ "Domain-wide delegation is not supported in universes other "
477
+ + "than googleapis.com"
478
+ )
479
+
480
+ assert not credentials .valid
481
+ assert credentials .expired
482
+
421
483
def test_expired (self ):
422
484
credentials = self .make_credentials (lifetime = None )
423
485
assert credentials .expired
@@ -810,3 +872,61 @@ def test_id_token_with_quota_project(
810
872
id_creds .refresh (request )
811
873
812
874
assert id_creds .quota_project_id == "project-foo"
875
+
876
+ def test_sign_jwt_request_success (self ):
877
+ principal = "foo@example.com"
878
+ expected_signed_jwt = "correct_signed_jwt"
879
+
880
+ response_body = {"keyId" : "1" , "signedJwt" : expected_signed_jwt }
881
+ request = self .make_request (
882
+ data = json .dumps (response_body ), status = http_client .OK
883
+ )
884
+
885
+ signed_jwt = impersonated_credentials ._sign_jwt_request (
886
+ request = request , principal = principal , headers = {}, payload = {}
887
+ )
888
+
889
+ assert signed_jwt == expected_signed_jwt
890
+ request .assert_called_once_with (
891
+ url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/foo@example.com:signJwt" ,
892
+ method = "POST" ,
893
+ headers = {},
894
+ body = json .dumps ({"delegates" : [], "payload" : json .dumps ({})}).encode (
895
+ "utf-8"
896
+ ),
897
+ )
898
+
899
+ def test_sign_jwt_request_http_error (self ):
900
+ principal = "foo@example.com"
901
+
902
+ request = self .make_request (
903
+ data = "error_message" , status = http_client .BAD_REQUEST
904
+ )
905
+
906
+ with pytest .raises (exceptions .RefreshError ) as excinfo :
907
+ _ = impersonated_credentials ._sign_jwt_request (
908
+ request = request , principal = principal , headers = {}, payload = {}
909
+ )
910
+
911
+ assert excinfo .match (impersonated_credentials ._REFRESH_ERROR )
912
+
913
+ assert excinfo .value .args [0 ] == "Unable to acquire impersonated credentials"
914
+ assert excinfo .value .args [1 ] == "error_message"
915
+
916
+ def test_sign_jwt_request_invalid_response_error (self ):
917
+ principal = "foo@example.com"
918
+
919
+ request = self .make_request (data = "invalid_data" , status = http_client .OK )
920
+
921
+ with pytest .raises (exceptions .RefreshError ) as excinfo :
922
+ _ = impersonated_credentials ._sign_jwt_request (
923
+ request = request , principal = principal , headers = {}, payload = {}
924
+ )
925
+
926
+ assert excinfo .match (impersonated_credentials ._REFRESH_ERROR )
927
+
928
+ assert (
929
+ excinfo .value .args [0 ]
930
+ == "Unable to acquire impersonated credentials: No signed JWT in response."
931
+ )
932
+ assert excinfo .value .args [1 ] == "invalid_data"
0 commit comments