Skip to content

Commit d6846bd

Browse files
committed
add missing type hints to util code
1 parent 102b292 commit d6846bd

File tree

5 files changed

+89
-63
lines changed

5 files changed

+89
-63
lines changed

policy_sentry/util/arns.py

Lines changed: 64 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,29 @@
88
# Case 5: arn:partition:service:region:account-id:resourcetype:resource
99
# Case 6: arn:partition:service:region:account-id:resourcetype:resource:qualifier
1010
# Source: https://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#genref-arns
11+
from __future__ import annotations
12+
1113
import logging
1214
import re
1315

16+
ARN_SEPARATOR_PATTERN = re.compile(r"[:/]")
17+
# Note: Each service can only have one of these, so these are definitely exceptions
18+
EXCLUSION_LIST = {
19+
"${ObjectName}",
20+
"${RepositoryName}",
21+
"${BucketName}",
22+
"table/${TableName}",
23+
"${BucketName}/${ObjectName}",
24+
}
25+
1426
logger = logging.getLogger(__name__)
1527

1628

1729
# pylint: disable=too-many-instance-attributes
1830
class ARN:
1931
"""Class that helps to match ARN resource type formats neatly"""
2032

21-
def __init__(self, provided_arn):
33+
def __init__(self, provided_arn: str) -> None:
2234
self.arn = provided_arn
2335
follows_arn_format = re.search(
2436
r"^arn:([^:]*):([^:]*):([^:]*):([^:]*):(.+)$", provided_arn
@@ -34,18 +46,20 @@ def __init__(self, provided_arn):
3446
self.account = elements[4]
3547
self.resource = elements[5]
3648
except IndexError as error:
37-
raise Exception(f"The provided ARN is invalid. IndexError: {error}. Please provide a valid ARN.") from error
49+
raise Exception(
50+
f"The provided ARN is invalid. IndexError: {error}. Please provide a valid ARN."
51+
) from error
3852
if "/" in self.resource:
3953
self.resource, self.resource_path = self.resource.split("/", 1)
4054
elif ":" in self.resource:
4155
self.resource, self.resource_path = self.resource.split(":", 1)
4256
self.resource_string = self._resource_string()
4357

44-
def __repr__(self):
58+
def __repr__(self) -> str:
4559
return self.arn
4660

4761
# pylint: disable=too-many-return-statements
48-
def _resource_string(self):
62+
def _resource_string(self) -> str:
4963
"""
5064
Given an ARN, return the string after the account ID, no matter the ARN format.
5165
Return:
@@ -62,7 +76,7 @@ def _resource_string(self):
6276
resource_string = ":".join(split_arn[5:])
6377
return resource_string
6478

65-
def same_resource_type(self, arn_in_database):
79+
def same_resource_type(self, arn_in_database: str) -> bool:
6680
"""Given an arn, see if it has the same resource type"""
6781

6882
# 1. If the RAW ARN in the database is *, then it doesn't have a resource type
@@ -80,14 +94,16 @@ def same_resource_type(self, arn_in_database):
8094
# Previously, this would fail and return empty results.
8195
# Now it correctly returns the full list of matching ARNs and corresponding actions.
8296
resource_type_arn_to_test = parse_arn_for_resource_type(self.arn)
83-
if resource_type_arn_to_test == '*':
97+
if resource_type_arn_to_test == "*":
8498
return True
8599

86100
# 4. Match patterns for complicated resource strings, leveraging the standardized format of the Raw ARN format
87101
# table/${TableName} should not match `table/${TableName}/backup/${BackupName}`
88102
resource_string_arn_in_database = get_resource_string(arn_in_database)
89103

90-
split_resource_string_in_database = re.split(':|/', resource_string_arn_in_database)
104+
split_resource_string_in_database = re.split(
105+
ARN_SEPARATOR_PATTERN, resource_string_arn_in_database
106+
)
91107
# logger.debug(str(split_resource_string_in_database))
92108
arn_format_list = []
93109
for elem in split_resource_string_in_database:
@@ -97,47 +113,47 @@ def same_resource_type(self, arn_in_database):
97113
# If an element says something like ${TableName}, normalize it to an empty string
98114
arn_format_list.append("")
99115

100-
split_resource_string_to_test = re.split(':|/', self.resource_string)
116+
split_resource_string_to_test = re.split(
117+
ARN_SEPARATOR_PATTERN, self.resource_string
118+
)
101119
# 4b: If we have a confusing resource string, the length of the split resource string list
102120
# should at least be the same
103121
# Again, table/${TableName} (len of 2) should not match `table/${TableName}/backup/${BackupName}` (len of 4)
104122
# if len(split_resource_string_to_test) != len(arn_format_list):
105123
# return False
106124

107-
non_empty_arn_format_list = []
108-
for i in arn_format_list:
109-
if i != "":
110-
non_empty_arn_format_list.append(i)
111-
112-
lower_resource_string = list(map(lambda x:x.lower(),split_resource_string_to_test))
113-
for i in non_empty_arn_format_list:
114-
if i.lower() not in lower_resource_string:
125+
lower_resource_string = [x.lower() for x in split_resource_string_to_test]
126+
for elem in arn_format_list:
127+
if elem and elem.lower() not in lower_resource_string:
115128
return False
116129

117130
# 4c: See if the non-normalized fields match
118-
for i in range(len(arn_format_list)):
131+
for idx, elem in enumerate(arn_format_list):
119132
# If the field is not normalized to empty string, then make sure the resource type segments match
120133
# So, using table/${TableName}/backup/${BackupName} as an example:
121134
# table should match, backup should match,
122135
# and length of the arn_format_list should be the same as split_resource_string_to_test
123136
# If all conditions match, then the ARN format is the same.
124-
if arn_format_list[i] != "":
125-
if arn_format_list[i] == split_resource_string_to_test[i]:
137+
if elem:
138+
if elem == split_resource_string_to_test[idx]:
126139
pass
127-
elif split_resource_string_to_test[i] == "*":
140+
elif split_resource_string_to_test[idx] == "*":
128141
pass
129142
else:
130143
return False
131144

132145
# 4. Special type for S3 bucket objects and CodeCommit repos
133-
# Note: Each service can only have one of these, so these are definitely exceptions
134-
exclusion_list = ["${ObjectName}", "${RepositoryName}", "${BucketName}", "table/${TableName}", "${BucketName}/${ObjectName}"]
135146
resource_path_arn_in_database = elements[5]
136-
if resource_path_arn_in_database in exclusion_list:
137-
logger.debug("Special type: %s", resource_path_arn_in_database)
147+
if resource_path_arn_in_database in EXCLUSION_LIST:
148+
logger.debug(f"Special type: {resource_path_arn_in_database}")
138149
# handling special case table/${TableName}
139-
if resource_string_arn_in_database in ["table/${TableName}", "${BucketName}"]:
140-
return len(self.resource_string.split('/')) == len(elements[5].split('/'))
150+
if resource_string_arn_in_database in (
151+
"table/${TableName}",
152+
"${BucketName}",
153+
):
154+
return len(self.resource_string.split("/")) == len(
155+
elements[5].split("/")
156+
)
141157
# If we've made it this far, then it is a special type
142158
# return True
143159
# Presence of / would mean it's an object in both so it matches
@@ -154,7 +170,7 @@ def same_resource_type(self, arn_in_database):
154170
return True
155171

156172

157-
def parse_arn(arn):
173+
def parse_arn(arn: str) -> dict[str, str]:
158174
"""
159175
Given an ARN, split up the ARN into the ARN namespacing schema dictated by the AWS docs.
160176
"""
@@ -167,53 +183,52 @@ def parse_arn(arn):
167183
"region": elements[3],
168184
"account": elements[4],
169185
"resource": elements[5],
170-
"resource_path": None,
186+
"resource_path": "",
171187
}
172188
except IndexError as error:
173-
raise Exception(f"IndexError: The provided ARN '{arn}' is invalid. Please provide a valid ARN.") from error
189+
raise Exception(
190+
f"IndexError: The provided ARN '{arn}' is invalid. Please provide a valid ARN."
191+
) from error
174192
if "/" in result["resource"]:
175193
result["resource"], result["resource_path"] = result["resource"].split("/", 1)
176194
elif ":" in result["resource"]:
177195
result["resource"], result["resource_path"] = result["resource"].split(":", 1)
178196
return result
179197

180198

181-
def get_service_from_arn(arn):
182-
"""Given an ARN string, return the service """
199+
def get_service_from_arn(arn: str) -> str:
200+
"""Given an ARN string, return the service"""
183201
result = parse_arn(arn)
184202
return result["service"]
185203

186204

187-
def get_region_from_arn(arn):
205+
def get_region_from_arn(arn: str) -> str:
188206
"""Given an ARN, return the region in the ARN, if it is available. In certain cases like S3 it is not"""
189207
result = parse_arn(arn)
190208
# Support S3 buckets with no values under region
191209
if result["region"] is None:
192-
result = ""
193-
else:
194-
result = result["region"]
195-
return result
210+
return ""
211+
return result["region"]
196212

197213

198-
def get_account_from_arn(arn):
214+
def get_account_from_arn(arn: str) -> str:
199215
"""Given an ARN, return the account ID in the ARN, if it is available. In certain cases like S3 it is not"""
200216
result = parse_arn(arn)
201217
# Support S3 buckets with no values under account
202218
if result["account"] is None:
203-
result = ""
204-
else:
205-
result = result["account"]
206-
return result
219+
return ""
220+
return result["account"]
207221

208222

209-
def get_resource_path_from_arn(arn):
223+
def get_resource_path_from_arn(arn: str) -> str | None:
210224
"""Given an ARN, parse it according to ARN namespacing and return the resource path. See
211-
http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html for more details on ARN namespacing."""
225+
https://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html for more details on ARN namespacing.
226+
"""
212227
result = parse_arn(arn)
213228
return result["resource_path"]
214229

215230

216-
def get_resource_string(arn):
231+
def get_resource_string(arn: str) -> str:
217232
"""
218233
Given an ARN, return the string after the account ID, no matter the ARN format.
219234
@@ -229,7 +244,7 @@ def get_resource_string(arn):
229244

230245
# In the meantime, we have to skip this pylint check (consider this as tech debt)
231246
# pylint: disable=inconsistent-return-statements
232-
def parse_arn_for_resource_type(arn):
247+
def parse_arn_for_resource_type(arn: str) -> str | None:
233248
"""
234249
Parses the resource string (resourcetype/resource and other variants) and grab the resource type.
235250
@@ -240,15 +255,17 @@ def parse_arn_for_resource_type(arn):
240255
"""
241256
split_arn = arn.split(":")
242257
resource_string = ":".join(split_arn[5:])
243-
split_resource = re.split("/|:", resource_string)
258+
split_resource = re.split(ARN_SEPARATOR_PATTERN, resource_string)
244259
if len(split_resource) == 1:
245260
# logger.debug(f"split_resource length is 1: {str(split_resource)}")
246261
pass
247262
elif len(split_resource) > 1:
248263
return split_resource[0]
249264

265+
return None
266+
250267

251-
def does_arn_match(arn_to_test, arn_in_database):
268+
def does_arn_match(arn_to_test: str, arn_in_database: str) -> bool:
252269
"""
253270
Given two ARNs, determine if they have the same resource type.
254271

policy_sentry/util/conditions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def get_comma_separated_condition_keys(condition_keys: str) -> str:
4444

4545

4646
def is_condition_key_match(document_key: str, some_str: str) -> bool:
47-
""" Given a documented condition key and one from a policy, determine if they match
47+
"""Given a documented condition key and one from a policy, determine if they match
4848
Examples:
4949
- s3:prefix and s3:prefix obviously match
5050
- s3:ExistingObjectTag/<key> and s3:ExistingObjectTag/backup match

policy_sentry/util/policy_files.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
"""A few methods for parsing policies."""
2+
from __future__ import annotations
3+
24
import json
35
import logging
4-
from operator import itemgetter
6+
from pathlib import Path
7+
from typing import Any
8+
9+
510
from policy_sentry.querying.actions import get_action_data
611

712
logger = logging.getLogger(__name__)
813

914

10-
def get_actions_from_statement(statement):
15+
def get_actions_from_statement(statement: dict[str, Any]) -> list[str]:
1116
"""Given a statement dictionary, create a list of the actions"""
12-
actions_list = []
17+
actions_list: list[str] = []
1318
# We only want to evaluate policies that have Effect = "Allow"
1419
if statement.get("Effect") == "Deny":
1520
return actions_list
@@ -30,7 +35,7 @@ def get_actions_from_statement(statement):
3035

3136

3237
# pylint: disable=too-many-branches,too-many-statements
33-
def get_actions_from_policy(data):
38+
def get_actions_from_policy(data: dict[str, Any]) -> list[str]:
3439
"""Given a policy dictionary, create a list of the actions"""
3540
actions_list = []
3641
statement_clause = data.get("Statement")
@@ -43,22 +48,20 @@ def get_actions_from_policy(data):
4348
actions_list.extend(get_actions_from_statement(statement))
4449
else:
4550
logger.critical("Unknown error: The 'Statement' is neither a dict nor a list")
46-
actions_list = [x.lower() for x in actions_list]
4751

4852
new_actions_list = []
4953
for action in actions_list:
5054
service, action_name = action.split(":")
5155
action_data = get_action_data(service, action_name)
52-
if service in action_data:
53-
if action_data[service]:
54-
new_actions_list.append(action_data[service][0]["action"])
56+
if service in action_data and action_data[service]:
57+
new_actions_list.append(action_data[service][0]["action"])
5558

5659
new_actions_list.sort()
5760
return new_actions_list
5861

5962

6063
# pylint: disable=too-many-branches,too-many-statements
61-
def get_actions_from_json_policy_file(file):
64+
def get_actions_from_json_policy_file(file: str | Path) -> list[str]:
6265
"""
6366
Read the JSON policy file and return a list of actions.
6467
"""
@@ -77,17 +80,23 @@ def get_actions_from_json_policy_file(file):
7780
return actions_list
7881

7982

80-
def get_sid_names_from_policy(policy_json):
83+
def get_sid_names_from_policy(policy_json: dict[str, Any]) -> list[str]:
8184
"""
8285
Given a Policy JSON, get a list of the Statement IDs. This is helpful in unit tests.
8386
"""
84-
sid_names = list(map(itemgetter("Sid"), policy_json.get("Statement")))
87+
sid_names = [
88+
statement["Sid"]
89+
for statement in policy_json.get("Statement", [])
90+
if "Sid" in statement
91+
]
8592
return sid_names
8693

8794

88-
def get_statement_from_policy_using_sid(policy_json, sid):
95+
def get_statement_from_policy_using_sid(
96+
policy_json: dict[str, Any], sid: str
97+
) -> dict[str, Any] | None:
8998
"""
9099
Helper function to get a statement just by providing the policy JSON and the Statement ID
91100
"""
92-
res = next((sub for sub in policy_json["Statement"] if sub['Sid'] == sid), None)
101+
res = next((sub for sub in policy_json["Statement"] if sub.get("Sid") == sid), None)
93102
return res

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
[tool.mypy]
22
files = "policy_sentry"
33
strict = true
4+
pretty = true
45

56
exclude = [
67
'^policy_sentry/analysis',
78
'^policy_sentry/bin',
89
'^policy_sentry/command',
910
'^policy_sentry/querying',
1011
'^policy_sentry/shared',
11-
'^policy_sentry/util/(arns|policy_files)',
1212
'^policy_sentry/writing',
1313
]

tasks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def run_linter(c):
247247
def run_mypy(c):
248248
"""Type checking with `mypy`"""
249249
try:
250-
c.run('mypy policy_sentry/')
250+
c.run('mypy')
251251
except UnexpectedExit as u_e:
252252
logger.critical(f"FAIL! UnexpectedExit: {u_e}")
253253
sys.exit(1)

0 commit comments

Comments
 (0)