Skip to content

Commit af41bad

Browse files
authored
Use role credentials when constructing clients for contract tests and use regional sts endpoint when getting account id (#548)
* Use role credentials when constructing clients for contract tests and use regional sts endpoint when getting account id * Don't use default credentials chain for boto helper get account * Use black to format files
1 parent 1492dbc commit af41bad

File tree

11 files changed

+166
-83
lines changed

11 files changed

+166
-83
lines changed

src/rpdk/core/boto_helpers.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,16 @@ def get_temporary_credentials(session, key_names=BOTO_CRED_KEYS, role_arn=None):
4646
response = sts_client.assume_role(
4747
RoleArn=role_arn, RoleSessionName=session_name
4848
)
49-
except ClientError as e:
49+
except ClientError:
50+
# pylint: disable=W1201
5051
LOG.debug(
51-
"Getting session token resulted in unknown ClientError", exc_info=e
52+
"Getting session token resulted in unknown ClientError. "
53+
+ "Could not assume specified role '%s'.",
54+
role_arn,
55+
)
56+
raise DownstreamError() from Exception(
57+
"Could not assume specified role '{}'".format(role_arn)
5258
)
53-
raise DownstreamError("Could not assume specified role") from e
5459
temp = response["Credentials"]
5560
creds = (temp["AccessKeyId"], temp["SecretAccessKey"], temp["SessionToken"])
5661
else:
@@ -78,7 +83,14 @@ def get_service_endpoint(service, region):
7883
return "https://" + endpoint_data["hostname"]
7984

8085

81-
def get_account(session):
82-
sts_client = session.client("sts")
86+
def get_account(session, temporary_credentials):
87+
sts_client = session.client(
88+
"sts",
89+
endpoint_url=get_service_endpoint("sts", session.region_name),
90+
region_name=session.region_name,
91+
aws_access_key_id=temporary_credentials["accessKeyId"],
92+
aws_secret_access_key=temporary_credentials["secretAccessKey"],
93+
aws_session_token=temporary_credentials["sessionToken"],
94+
)
8395
response = sts_client.get_caller_identity()
8496
return response.get("Account")

src/rpdk/core/contract/resource_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def __init__(
132132
self._session, LOWER_CAMEL_CRED_KEYS, role_arn
133133
)
134134
self.region = region
135-
self.account = get_account(self._session)
135+
self.account = get_account(self._session, self._creds)
136136
self.partition = self._get_partition()
137137
self._function_name = function_name
138138
if endpoint.startswith("http://"):

src/rpdk/core/contract/suite/handler_commons.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ def test_create_failure_if_repeat_writeable_id(resource_client, current_resource
5050
@response_contains_resource_model_equal_current_model
5151
def test_read_success(resource_client, current_resource_model):
5252
primay_identifier_only_model = create_model_with_properties_in_path(
53-
current_resource_model.copy(), resource_client.primary_identifier_paths,
53+
current_resource_model.copy(),
54+
resource_client.primary_identifier_paths,
5455
)
5556
_status, response, _error_code = resource_client.call_and_assert(
5657
Action.READ, OperationStatus.SUCCESS, primay_identifier_only_model
@@ -61,7 +62,8 @@ def test_read_success(resource_client, current_resource_model):
6162
@failed_event(error_code=HandlerErrorCode.NotFound)
6263
def test_read_failure_not_found(resource_client, current_resource_model):
6364
primay_identifier_only_model = create_model_with_properties_in_path(
64-
current_resource_model, resource_client.primary_identifier_paths,
65+
current_resource_model,
66+
resource_client.primary_identifier_paths,
6567
)
6668
_status, _response, error_code = resource_client.call_and_assert(
6769
Action.READ, OperationStatus.FAILED, primay_identifier_only_model
@@ -124,7 +126,8 @@ def test_update_failure_not_found(resource_client, current_resource_model):
124126

125127
def test_delete_success(resource_client, current_resource_model):
126128
primay_identifier_only_model = create_model_with_properties_in_path(
127-
current_resource_model, resource_client.primary_identifier_paths,
129+
current_resource_model,
130+
resource_client.primary_identifier_paths,
128131
)
129132
_status, response, _error_code = resource_client.call_and_assert(
130133
Action.DELETE, OperationStatus.SUCCESS, primay_identifier_only_model
@@ -135,7 +138,8 @@ def test_delete_success(resource_client, current_resource_model):
135138
@failed_event(error_code=HandlerErrorCode.NotFound)
136139
def test_delete_failure_not_found(resource_client, current_resource_model):
137140
primay_identifier_only_model = create_model_with_properties_in_path(
138-
current_resource_model, resource_client.primary_identifier_paths,
141+
current_resource_model,
142+
resource_client.primary_identifier_paths,
139143
)
140144
_status, _response, error_code = resource_client.call_and_assert(
141145
Action.DELETE, OperationStatus.FAILED, primay_identifier_only_model

src/rpdk/core/contract/suite/handler_delete.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def deleted_resource(resource_client):
3636
model = response["resourceModel"]
3737
test_input_equals_output(resource_client, input_model, model)
3838
primay_identifier_only_model = create_model_with_properties_in_path(
39-
model, resource_client.primary_identifier_paths,
39+
model,
40+
resource_client.primary_identifier_paths,
4041
)
4142
_status, response, _error = resource_client.call_and_assert(
4243
Action.DELETE, OperationStatus.SUCCESS, primay_identifier_only_model

src/rpdk/core/contract/suite/handler_update_invalid.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ def contract_update_create_only_property(resource_client):
3030
assert response["message"]
3131
finally:
3232
primay_identifier_only_model = create_model_with_properties_in_path(
33-
created_model, resource_client.primary_identifier_paths,
33+
created_model,
34+
resource_client.primary_identifier_paths,
3435
)
3536
resource_client.call_and_assert(
3637
Action.DELETE, OperationStatus.SUCCESS, primay_identifier_only_model

src/rpdk/core/jsonutils/pointer.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -99,21 +99,21 @@ def fragment_decode(pointer, prefix="#", output=tuple):
9999

100100
def fragment_list(segments, prefix="properties", output=list):
101101
"""Decode all segments of a JSON pointer from the URI fragment
102-
identifier representation.
102+
identifier representation.
103103
104-
>>> fragment_list(["properties"])
105-
[]
106-
>>> fragment_list(["properties", "foo", "bar"])
107-
['foo', 'bar']
108-
>>> fragment_list(["properties", "foo", "bar"], output=tuple)
109-
('foo', 'bar')
110-
>>> fragment_list(["properties", "0", "%20", "~0"])
111-
['0', ' ', '~']
112-
>>> fragment_list(["foo"])
113-
Traceback (most recent call last):
114-
...
115-
ValueError: Expected prefix 'properties', but was 'foo'
116-
"""
104+
>>> fragment_list(["properties"])
105+
[]
106+
>>> fragment_list(["properties", "foo", "bar"])
107+
['foo', 'bar']
108+
>>> fragment_list(["properties", "foo", "bar"], output=tuple)
109+
('foo', 'bar')
110+
>>> fragment_list(["properties", "0", "%20", "~0"])
111+
['0', ' ', '~']
112+
>>> fragment_list(["foo"])
113+
Traceback (most recent call last):
114+
...
115+
ValueError: Expected prefix 'properties', but was 'foo'
116+
"""
117117
decoded = (part_decode(unquote(segment)) for segment in segments)
118118
actual = next(decoded)
119119
if prefix != actual:

src/rpdk/core/test.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from rpdk.core.jsonutils.pointer import fragment_decode
1919

20-
from .boto_helpers import create_sdk_session
20+
from .boto_helpers import create_sdk_session, get_temporary_credentials
2121
from .contract.contract_plugin import ContractPlugin
2222
from .contract.interface import Action
2323
from .contract.resource_client import ResourceClient
@@ -57,9 +57,12 @@ def temporary_ini_file():
5757
yield str(path)
5858

5959

60-
def get_cloudformation_exports(region_name, endpoint_url):
60+
def get_cloudformation_exports(region_name, endpoint_url, role_arn):
6161
session = create_sdk_session(region_name)
62-
cfn_client = session.client("cloudformation", endpoint_url=endpoint_url)
62+
temp_credentials = get_temporary_credentials(session, role_arn=role_arn)
63+
cfn_client = session.client(
64+
"cloudformation", endpoint_url=endpoint_url, **temp_credentials
65+
)
6366
paginator = cfn_client.get_paginator("list_exports")
6467
pages = paginator.paginate()
6568
exports = {}
@@ -68,12 +71,12 @@ def get_cloudformation_exports(region_name, endpoint_url):
6871
return exports
6972

7073

71-
def render_jinja(overrides_string, region_name, endpoint_url):
74+
def render_jinja(overrides_string, region_name, endpoint_url, role_arn):
7275
env = Environment(autoescape=True)
7376
parsed_content = env.parse(overrides_string)
7477
variables = meta.find_undeclared_variables(parsed_content)
7578
if variables:
76-
exports = get_cloudformation_exports(region_name, endpoint_url)
79+
exports = get_cloudformation_exports(region_name, endpoint_url, role_arn)
7780
invalid_exports = variables - exports.keys()
7881
if len(invalid_exports) > 0:
7982
invalid_exports_message = (
@@ -89,14 +92,14 @@ def render_jinja(overrides_string, region_name, endpoint_url):
8992
return to_return
9093

9194

92-
def get_overrides(root, region_name, endpoint_url):
95+
def get_overrides(root, region_name, endpoint_url, role_arn):
9396
if not root:
9497
return empty_override()
9598

9699
path = root / "overrides.json"
97100
try:
98101
with path.open("r", encoding="utf-8") as f:
99-
overrides_raw = render_jinja(f.read(), region_name, endpoint_url)
102+
overrides_raw = render_jinja(f.read(), region_name, endpoint_url, role_arn)
100103
except FileNotFoundError:
101104
LOG.debug("Override file '%s' not found. No overrides will be applied", path)
102105
return empty_override()
@@ -123,7 +126,7 @@ def get_overrides(root, region_name, endpoint_url):
123126

124127

125128
# pylint: disable=R0914
126-
def get_inputs(root, region_name, endpoint_url, value):
129+
def get_inputs(root, region_name, endpoint_url, value, role_arn):
127130
inputs = {}
128131
if not root:
129132
return None
@@ -144,7 +147,9 @@ def get_inputs(root, region_name, endpoint_url, value):
144147

145148
file_path = path / file
146149
with file_path.open("r", encoding="utf-8") as f:
147-
overrides_raw = render_jinja(f.read(), region_name, endpoint_url)
150+
overrides_raw = render_jinja(
151+
f.read(), region_name, endpoint_url, role_arn
152+
)
148153
overrides = {}
149154
for pointer, obj in overrides_raw.items():
150155
overrides[pointer] = obj
@@ -175,13 +180,17 @@ def test(args):
175180
project.load()
176181

177182
overrides = get_overrides(
178-
project.root, args.region, args.cloudformation_endpoint_url
183+
project.root, args.region, args.cloudformation_endpoint_url, args.role_arn
179184
)
180185

181186
index = 1
182187
while True:
183188
inputs = get_inputs(
184-
project.root, args.region, args.cloudformation_endpoint_url, index
189+
project.root,
190+
args.region,
191+
args.cloudformation_endpoint_url,
192+
index,
193+
args.role_arn,
185194
)
186195
if not inputs:
187196
break

tests/contract/test_resource_client.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def resource_client():
8080

8181
mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
8282
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
83-
mock_account.assert_called_once_with(mock_sesh)
83+
mock_account.assert_called_once_with(mock_sesh, {})
8484
assert client._creds == {}
8585
assert client._function_name == DEFAULT_FUNCTION
8686
assert client._schema == {}
@@ -121,7 +121,7 @@ def resource_client_inputs():
121121

122122
mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
123123
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
124-
mock_account.assert_called_once_with(mock_sesh)
124+
mock_account.assert_called_once_with(mock_sesh, {})
125125

126126
assert client._creds == {}
127127
assert client._function_name == DEFAULT_FUNCTION
@@ -213,7 +213,9 @@ def test_init_sam_cli_client():
213213
"rpdk.core.contract.resource_client.create_sdk_session", autospec=True
214214
)
215215
patch_creds = patch(
216-
"rpdk.core.contract.resource_client.get_temporary_credentials", autospec=True
216+
"rpdk.core.contract.resource_client.get_temporary_credentials",
217+
autospec=True,
218+
return_value={},
217219
)
218220
patch_account = patch(
219221
"rpdk.core.contract.resource_client.get_account",
@@ -232,7 +234,7 @@ def test_init_sam_cli_client():
232234
"lambda", endpoint_url=DEFAULT_ENDPOINT, use_ssl=False, verify=False, config=ANY
233235
)
234236
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
235-
mock_account.assert_called_once_with(mock_sesh)
237+
mock_account.assert_called_once_with(mock_sesh, {})
236238
assert client.account == ACCOUNT
237239

238240

tests/test_boto_helpers.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,26 @@ def test_get_temporary_credentials_assume_role():
200200
assert tuple(creds.values()) == (access_key, secret_key, token)
201201

202202

203-
def test_get_account():
203+
def test_get_account_with_temporary_credentials():
204204
session = create_autospec(spec=Session, spec_set=True)
205205
client = session.client.return_value
206-
get_account(session)
206+
session.region_name = "us-east-1"
207+
access_key = object()
208+
secret_key = object()
209+
token = object()
210+
creds = {
211+
"accessKeyId": access_key,
212+
"secretAccessKey": secret_key,
213+
"sessionToken": token,
214+
}
215+
get_account(session, creds)
207216

208-
session.client.assert_called_once_with("sts")
217+
session.client.assert_called_once_with(
218+
"sts",
219+
aws_access_key_id=access_key,
220+
aws_secret_access_key=secret_key,
221+
aws_session_token=token,
222+
endpoint_url="https://sts.us-east-1.amazonaws.com",
223+
region_name="us-east-1",
224+
)
209225
client.get_caller_identity.assert_called_once()

0 commit comments

Comments
 (0)