Skip to content

Commit ada7f9c

Browse files
committed
added lock file tamper detection
1 parent 3b20ccf commit ada7f9c

File tree

5 files changed

+23
-6
lines changed

5 files changed

+23
-6
lines changed

adk/ADK.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def __init__(self, apply_func, load_func=None, client=None, manifest_path="model
1313
:param apply_func: A required function that can have an arity of 1-2, depending on if loading occurs
1414
:param load_func: An optional supplier function used if load time events are required, has an arity of 0.
1515
:param client: A Algorithmia Client instance that might be user defined, and is used for interacting with a model manifest file; if defined.
16+
:param manifest_path: A development / testing facing variable used to set the name and path
1617
"""
1718
self.FIFO_PATH = "/tmp/algoout"
1819
apply_args, _, _, _, _, _, _ = inspect.getfullargspec(apply_func)

adk/manifest.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def initialize(self):
2828
expected_hash = required_file['md5_checksum']
2929
with self.client.file(required_file['data_api_path']).getFile() as f:
3030
local_data_path = f.name
31-
real_hash = md5(local_data_path)
31+
real_hash = md5_for_file(local_data_path)
3232
if real_hash != expected_hash and required_file['fail_on_tamper']:
3333
raise Exception("Model File Mismatch for " + name +
3434
"\nexpected hash: " + expected_hash + "\nreal hash: " + real_hash)
@@ -57,7 +57,7 @@ def find_optional_model(self, model_name):
5757
expected_hash = model_info['md5_checksum']
5858
with self.client.file(model_info['data_api_path']).getFile() as f:
5959
local_data_path = f.name
60-
real_hash = md5(local_data_path)
60+
real_hash = md5_for_file(local_data_path)
6161
if real_hash != expected_hash and model_info['fail_on_tamper']:
6262
raise Exception("Model File Mismatch for " + model_name +
6363
"\nexpected hash: " + expected_hash + "\nreal hash: " + real_hash)
@@ -70,14 +70,26 @@ def get_manifest(manifest_path):
7070
if os.path.exists(manifest_path):
7171
with open(manifest_path) as f:
7272
manifest_data = json.load(f)
73+
expected_lock_checksum = manifest_data.get('lock_checksum')
74+
del manifest_data['lock_checksum']
75+
detected_lock_checksum = md5_for_str(str(manifest_data))
76+
if expected_lock_checksum != detected_lock_checksum:
77+
raise Exception("Manifest Lockfile Tamper Detected; please use the CLI and 'algo compile' to rebuild your "
78+
"algorithm's lock file.")
7379
return manifest_data
7480
else:
7581
return None
7682

7783

78-
def md5(fname):
84+
def md5_for_file(fname):
7985
hash_md5 = hashlib.md5()
8086
with open(fname, "rb") as f:
8187
for chunk in iter(lambda: f.read(4096), b""):
8288
hash_md5.update(chunk)
8389
return str(hash_md5.hexdigest())
90+
91+
92+
def md5_for_str(content):
93+
hash_md5 = hashlib.md5()
94+
hash_md5.update(content.encode())
95+
return str(hash_md5.hexdigest())

tests/bad_model_manifest.json.lock renamed to tests/manifests/bad_model_manifest.json.lock

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{
22
"algorithm_name": "test_algorithm",
33
"timestamp": "1632770803",
4+
"lock_checksum": "57dcf8ab156f0c86f2f275919dcbf090",
45
"required_models" : [
56
{ "name": "squeezenet",
67
"data_api_path": "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth",

tests/good_model_manifest.json.lock renamed to tests/manifests/good_model_manifest.json.lock

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{
22
"algorithm_name": "test_algorithm",
33
"timestamp": "1632770803",
4+
"lock_checksum": "fba73f61886f0921b47997057e853d36",
45
"required_models" : [
56
{ "name": "squeezenet",
67
"data_api_path": "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth",

tests/test_adk_local.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def execute_example(self, input, apply, load=lambda: None):
2020
algo.init(input, pprint=lambda x: output.append(x))
2121
return output[0]
2222

23-
def execute_manifest_example(self, input, apply, load, manifest_path="good_model_manifest.json.lock"):
23+
def execute_manifest_example(self, input, apply, load, manifest_path="manifests/good_model_manifest.json.lock"):
2424
client = Algorithmia.client()
2525
algo = ADK(apply, load, manifest_path=manifest_path, client=client)
2626
output = []
@@ -127,7 +127,8 @@ def test_manifest_file_success(self):
127127
}
128128
actual_output = json.loads(self.execute_manifest_example(input, apply_successful_manifest_parsing,
129129
loading_with_manifest,
130-
manifest_path="tests/good_model_manifest.json.lock"))
130+
manifest_path="tests/manifests/good_model_manifest"
131+
".json.lock"))
131132
self.assertEqual(expected_output, actual_output)
132133

133134
def test_manifest_file_tampered(self):
@@ -140,7 +141,8 @@ def test_manifest_file_tampered(self):
140141

141142
actual_output = json.loads(self.execute_manifest_example(input, apply_successful_manifest_parsing,
142143
loading_with_manifest,
143-
manifest_path="tests/bad_model_manifest.json.lock"))
144+
manifest_path="tests/manifests/bad_model_manifest"
145+
".json.lock"))
144146
self.assertEqual(expected_output, actual_output)
145147

146148

0 commit comments

Comments
 (0)