Skip to content

Commit aa6db3a

Browse files
authored
Adding headers in STS Calls for Confused Deputy (#1061)
1 parent b8ba39b commit aa6db3a

File tree

9 files changed

+187
-46
lines changed

9 files changed

+187
-46
lines changed

.pylintrc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ disable=
1212
duplicate-code, # finds dupes between tests and plugins
1313
too-few-public-methods, # triggers when inheriting
1414
ungrouped-imports, # clashes with isort
15+
W0613 # Unused argument 'kwargs'
1516

1617
[BASIC]
1718

@@ -23,4 +24,5 @@ indent-string=' '
2324
max-line-length=160
2425

2526
[DESIGN]
26-
max-locals=16
27+
max-locals=17
28+
max-args=6

src/rpdk/core/boto_helpers.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,35 @@ def _known_error(msg):
3232
return session
3333

3434

35-
def get_temporary_credentials(session, key_names=BOTO_CRED_KEYS, role_arn=None):
35+
def get_temporary_credentials(
36+
session, key_names=BOTO_CRED_KEYS, role_arn=None, headers=None
37+
):
3638
sts_client = session.client(
3739
"sts",
3840
endpoint_url=get_service_endpoint("sts", session.region_name),
3941
region_name=session.region_name,
4042
)
43+
check_keys = {"account_id", "source_arn"}
44+
if (
45+
headers
46+
and check_keys.issubset(headers.keys())
47+
and headers["account_id"]
48+
and headers["source_arn"]
49+
):
50+
# Inject headers through the event system.
51+
def inject_confused_deputy_headers(params, **kwargs):
52+
params["headers"]["x-amz-source-account"] = headers["account_id"]
53+
params["headers"]["x-amz-source-arn"] = headers["source_arn"]
54+
55+
sts_client.meta.events.register("before-call", inject_confused_deputy_headers)
56+
LOG.info(headers)
4157
if role_arn:
4258
session_name = f"CloudFormationContractTest-{datetime.now():%Y%m%d%H%M%S}"
4359
try:
4460
response = sts_client.assume_role(
45-
RoleArn=role_arn, RoleSessionName=session_name, DurationSeconds=900
61+
RoleArn=role_arn,
62+
RoleSessionName=session_name,
63+
DurationSeconds=900,
4664
)
4765
except ClientError:
4866
# pylint: disable=W1201

src/rpdk/core/contract/hook_client.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(
5656
type_name=None,
5757
log_group_name=None,
5858
log_role_arn=None,
59+
headers=None,
5960
docker_image=None,
6061
typeconfig=None,
6162
executable_entrypoint=None,
@@ -69,9 +70,12 @@ def __init__(
6970
self._log_group_name = log_group_name
7071
self._log_role_arn = log_role_arn
7172
self.region = region
73+
self._headers = headers
7274
self.account = get_account(
7375
self._session,
74-
get_temporary_credentials(self._session, LOWER_CAMEL_CRED_KEYS, role_arn),
76+
get_temporary_credentials(
77+
self._session, LOWER_CAMEL_CRED_KEYS, role_arn, headers
78+
),
7579
)
7680
self._function_name = function_name
7781
if endpoint.startswith("http://"):
@@ -396,11 +400,11 @@ def _make_payload(
396400
self.account,
397401
invocation_point,
398402
get_temporary_credentials(
399-
self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn
403+
self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn, self._headers
400404
),
401405
self._log_group_name,
402406
get_temporary_credentials(
403-
self._session, LOWER_CAMEL_CRED_KEYS, self._log_role_arn
407+
self._session, LOWER_CAMEL_CRED_KEYS, self._log_role_arn, self._headers
404408
),
405409
self.generate_token(),
406410
target_model,

src/rpdk/core/contract/resource_client.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def __init__(
171171
type_name=None,
172172
log_group_name=None,
173173
log_role_arn=None,
174+
headers=None,
174175
docker_image=None,
175176
typeconfig=None,
176177
executable_entrypoint=None,
@@ -182,9 +183,12 @@ def __init__(
182183
self._log_group_name = log_group_name
183184
self._log_role_arn = log_role_arn
184185
self.region = region
186+
self._headers = headers
185187
self.account = get_account(
186188
self._session,
187-
get_temporary_credentials(self._session, LOWER_CAMEL_CRED_KEYS, role_arn),
189+
get_temporary_credentials(
190+
self._session, LOWER_CAMEL_CRED_KEYS, role_arn, headers
191+
),
188192
)
189193
self._function_name = function_name
190194
if endpoint.startswith("http://"):
@@ -674,12 +678,12 @@ def _make_payload(
674678
self.account,
675679
action,
676680
get_temporary_credentials(
677-
self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn
681+
self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn, self._headers
678682
),
679683
self._type_name,
680684
self._log_group_name,
681685
get_temporary_credentials(
682-
self._session, LOWER_CAMEL_CRED_KEYS, self._log_role_arn
686+
self._session, LOWER_CAMEL_CRED_KEYS, self._log_role_arn, self._headers
683687
),
684688
self.generate_token(),
685689
type_configuration=type_configuration,
@@ -794,7 +798,7 @@ def call(self, action, current_model, previous_model=None, **kwargs):
794798
request["callbackContext"] = response.get("callbackContext")
795799
# refresh credential for every handler invocation
796800
request["requestData"]["callerCredentials"] = get_temporary_credentials(
797-
self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn
801+
self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn, self._headers
798802
)
799803

800804
response = self._call(request)

src/rpdk/core/test.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,13 @@ def temporary_ini_file():
102102
yield str(path)
103103

104104

105-
def get_cloudformation_exports(region_name, endpoint_url, role_arn, profile_name):
105+
def get_cloudformation_exports(
106+
region_name, endpoint_url, role_arn, profile_name, headers
107+
):
106108
session = create_sdk_session(region_name, profile_name)
107-
temp_credentials = get_temporary_credentials(session, role_arn=role_arn)
109+
temp_credentials = get_temporary_credentials(
110+
session, role_arn=role_arn, headers=headers
111+
)
108112
cfn_client = session.client(
109113
"cloudformation", endpoint_url=endpoint_url, **temp_credentials
110114
)
@@ -132,13 +136,13 @@ def __retrieve_args(match):
132136

133137

134138
def render_template(
135-
overrides_string, region_name, endpoint_url, role_arn, profile_name
139+
overrides_string, region_name, endpoint_url, role_arn, profile_name, headers
136140
):
137141
regex = r"{{([-A-Za-z0-9:\s]+?)}}"
138142
variables = set(str(match).strip() for match in re.findall(regex, overrides_string))
139143
if variables:
140144
exports = get_cloudformation_exports(
141-
region_name, endpoint_url, role_arn, profile_name
145+
region_name, endpoint_url, role_arn, profile_name, headers
142146
)
143147
invalid_exports = variables - exports.keys()
144148
if len(invalid_exports) > 0:
@@ -166,15 +170,20 @@ def filter_overrides(overrides, project):
166170
return overrides
167171

168172

169-
def get_overrides(root, region_name, endpoint_url, role_arn, profile_name):
173+
def get_overrides(root, region_name, endpoint_url, role_arn, profile_name, headers):
170174
if not root:
171175
return empty_override()
172176

173177
path = root / "overrides.json"
174178
try:
175179
with path.open("r", encoding="utf-8") as f:
176180
overrides_raw = render_template(
177-
f.read(), region_name, endpoint_url, role_arn, profile_name
181+
f.read(),
182+
region_name,
183+
endpoint_url,
184+
role_arn,
185+
profile_name,
186+
headers=headers,
178187
)
179188
except FileNotFoundError:
180189
LOG.debug("Override file '%s' not found. No overrides will be applied", path)
@@ -203,15 +212,22 @@ def get_overrides(root, region_name, endpoint_url, role_arn, profile_name):
203212

204213
# pylint: disable=R0914
205214
# flake8: noqa: C901
206-
def get_hook_overrides(root, region_name, endpoint_url, role_arn, profile_name):
215+
def get_hook_overrides(
216+
root, region_name, endpoint_url, role_arn, profile_name, headers
217+
):
207218
if not root:
208219
return empty_hook_override()
209220

210221
path = root / "overrides.json"
211222
try:
212223
with path.open("r", encoding="utf-8") as f:
213224
overrides_raw = render_template(
214-
f.read(), region_name, endpoint_url, role_arn, profile_name
225+
f.read(),
226+
region_name,
227+
endpoint_url,
228+
role_arn,
229+
profile_name,
230+
headers=headers,
215231
)
216232
except FileNotFoundError:
217233
LOG.debug("Override file '%s' not found. No overrides will be applied", path)
@@ -258,7 +274,7 @@ def get_hook_overrides(root, region_name, endpoint_url, role_arn, profile_name):
258274

259275

260276
# pylint: disable=R0914,too-many-arguments
261-
def get_inputs(root, region_name, endpoint_url, value, role_arn, profile_name):
277+
def get_inputs(root, region_name, endpoint_url, value, role_arn, profile_name, headers):
262278
inputs = {}
263279
if not root:
264280
return None
@@ -280,7 +296,12 @@ def get_inputs(root, region_name, endpoint_url, value, role_arn, profile_name):
280296
file_path = path / file
281297
with file_path.open("r", encoding="utf-8") as f:
282298
overrides_raw = render_template(
283-
f.read(), region_name, endpoint_url, role_arn, profile_name
299+
f.read(),
300+
region_name,
301+
endpoint_url,
302+
role_arn,
303+
profile_name,
304+
headers=headers,
284305
)
285306
overrides = {}
286307
for pointer, obj in overrides_raw.items():
@@ -355,6 +376,7 @@ def get_contract_plugin_client(args, project, overrides, inputs):
355376
project.type_name,
356377
args.log_group_name,
357378
args.log_role_arn,
379+
headers={"account_id": args.source_account, "source_arn": args.source_arn},
358380
executable_entrypoint=project.executable_entrypoint,
359381
docker_image=args.docker_image,
360382
typeconfig=args.typeconfig,
@@ -378,6 +400,7 @@ def get_contract_plugin_client(args, project, overrides, inputs):
378400
project.type_name,
379401
args.log_group_name,
380402
args.log_role_arn,
403+
headers={"account_id": args.source_account, "source_arn": args.source_arn},
381404
typeconfig=args.typeconfig,
382405
executable_entrypoint=project.executable_entrypoint,
383406
docker_image=args.docker_image,
@@ -402,6 +425,7 @@ def test(args):
402425
args.cloudformation_endpoint_url,
403426
args.role_arn,
404427
args.profile,
428+
headers={"account_id": args.source_account, "source_arn": args.source_arn},
405429
)
406430
else:
407431
overrides = get_overrides(
@@ -410,6 +434,7 @@ def test(args):
410434
args.cloudformation_endpoint_url,
411435
args.role_arn,
412436
args.profile,
437+
headers={"account_id": args.source_account, "source_arn": args.source_arn},
413438
)
414439
filter_overrides(overrides, project)
415440

@@ -422,6 +447,7 @@ def test(args):
422447
index,
423448
args.role_arn,
424449
args.profile,
450+
headers={"account_id": args.source_account, "source_arn": args.source_arn},
425451
)
426452
if not inputs:
427453
break
@@ -509,6 +535,15 @@ def setup_subparser(subparsers, parents):
509535
" '~/.cfn-cli/typeConfiguration.json.'"
510536
),
511537
)
538+
parser.add_argument(
539+
"--source-account",
540+
help="Source Account key used for Assume Role to Run Contract Tests",
541+
)
542+
543+
parser.add_argument(
544+
"--source-arn",
545+
help="Source Type Version Arn key used for Assume Role to Run Contract Tests",
546+
)
512547

513548

514549
def _sam_arguments(parser):

tests/contract/test_hook_client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def hook_client():
121121
)
122122

123123
mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
124-
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
124+
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
125125
mock_account.assert_called_once_with(mock_sesh, {})
126126
assert client._function_name == DEFAULT_FUNCTION
127127
assert client._schema == SCHEMA_
@@ -179,7 +179,7 @@ def hook_client_inputs():
179179
)
180180

181181
mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
182-
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
182+
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
183183
mock_account.assert_called_once_with(mock_sesh, {})
184184
assert client._function_name == DEFAULT_FUNCTION
185185
assert client._schema == SCHEMA_
@@ -215,7 +215,7 @@ def test_init_sam_cli_client():
215215
mock_sesh.client.assert_called_once_with(
216216
"lambda", endpoint_url=DEFAULT_ENDPOINT, use_ssl=False, verify=False, config=ANY
217217
)
218-
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
218+
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
219219
mock_account.assert_called_once_with(mock_sesh, {})
220220
assert client.account == ACCOUNT
221221

tests/contract/test_resource_client.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def resource_client():
179179
)
180180

181181
mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
182-
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
182+
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
183183
mock_account.assert_called_once_with(mock_sesh, {})
184184
assert client._function_name == DEFAULT_FUNCTION
185185
assert client._schema == EMPTY_SCHEMA
@@ -214,7 +214,7 @@ def resource_client_no_handler():
214214
)
215215

216216
mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
217-
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
217+
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
218218
mock_account.assert_called_once_with(mock_sesh, {})
219219
assert client._function_name == DEFAULT_FUNCTION
220220
assert client._schema == {}
@@ -254,7 +254,7 @@ def resource_client_inputs():
254254
)
255255

256256
mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
257-
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
257+
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
258258
mock_account.assert_called_once_with(mock_sesh, {})
259259

260260
assert client._function_name == DEFAULT_FUNCTION
@@ -299,7 +299,7 @@ def resource_client_inputs_schema(request):
299299
)
300300

301301
mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
302-
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
302+
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
303303
mock_account.assert_called_once_with(mock_sesh, {})
304304

305305
assert client._function_name == DEFAULT_FUNCTION
@@ -344,7 +344,7 @@ def resource_client_inputs_composite_key():
344344
)
345345

346346
mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
347-
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
347+
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
348348
mock_account.assert_called_once_with(mock_sesh, {})
349349

350350
assert client._function_name == DEFAULT_FUNCTION
@@ -384,7 +384,7 @@ def resource_client_inputs_property_transform():
384384
)
385385

386386
mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint)
387-
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
387+
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
388388
mock_account.assert_called_once_with(mock_sesh, {})
389389
assert client._function_name == DEFAULT_FUNCTION
390390
assert client._schema == SCHEMA_WITH_PROPERTY_TRANSFORM
@@ -693,7 +693,7 @@ def test_init_sam_cli_client():
693693
mock_sesh.client.assert_called_once_with(
694694
"lambda", endpoint_url=DEFAULT_ENDPOINT, use_ssl=False, verify=False, config=ANY
695695
)
696-
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None)
696+
mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None, None)
697697
mock_account.assert_called_once_with(mock_sesh, {})
698698
assert client.account == ACCOUNT
699699

0 commit comments

Comments
 (0)