Skip to content

Commit 5487dbd

Browse files
committed
restructured to remove unnecessary user_data and state_data objects, making the modelData class be it's own dict
1 parent 769bdb1 commit 5487dbd

File tree

5 files changed

+37
-22
lines changed

5 files changed

+37
-22
lines changed

README.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def load(modelData):
106106
# during runtime.
107107

108108
# Any variables returned here, will be passed as the secondary argument to your 'algorithm' function
109-
modelData.user_data['payload'] = "Loading has been completed."
109+
modelData['payload'] = "Loading has been completed."
110110
return modelData
111111

112112

@@ -176,9 +176,9 @@ def infer_image(image_url, n, globals):
176176

177177
def load(modelData):
178178

179-
modelData.user_data["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
180-
modelData.user_data["model"] = load_model(modelData.get_model("squeezenet"))
181-
modelData.user_data["labels"] = load_labels(modelData.get_model("labels"))
179+
modelData["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
180+
modelData["model"] = load_model(modelData.get_model("squeezenet"))
181+
modelData["labels"] = load_labels(modelData.get_model("labels"))
182182
return modelData
183183

184184

@@ -190,10 +190,10 @@ def apply(input, modelData):
190190
n = 3
191191
if "data" in input:
192192
if isinstance(input["data"], str):
193-
output = infer_image(input["data"], n, modelData.user_data)
193+
output = infer_image(input["data"], n, modelData)
194194
elif isinstance(input["data"], list):
195195
for row in input["data"]:
196-
row["predictions"] = infer_image(row["image_url"], n, modelData.user_data)
196+
row["predictions"] = infer_image(row["image_url"], n, modelData)
197197
output = input["data"]
198198
else:
199199
raise Exception("\"data\" must be a image url or a list of image urls (with labels)")
@@ -257,4 +257,4 @@ Verify that it works on pytest, then:
257257
```commandline
258258
python -m twine upload -r pypi dist/*
259259
```
260-
and you're done :)
260+
and you're done :)

adk/modeldata.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,23 @@ def __init__(self, client, model_manifest_path):
1010
self.manifest_data = get_manifest(self.manifest_freeze_path)
1111
self.client = client
1212
self.models = {}
13-
self.user_data = {}
14-
self.system_data = {}
13+
self.usr_key = "__user__"
14+
15+
def __getitem__(self, key):
16+
return getattr(self, self.usr_key + key)
17+
18+
def __setitem__(self, key, value):
19+
setattr(self, self.usr_key + key, value)
20+
21+
def data(self):
22+
__dict = self.__dict__
23+
output = {}
24+
for key in __dict.keys():
25+
if self.usr_key in key:
26+
without_usr_key = key.split(self.usr_key)[1]
27+
output[without_usr_key] = __dict[key]
28+
return output
29+
1530

1631
def available(self):
1732
if self.manifest_data:

examples/loaded_state_hello_world/src/Algorithm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def load(modelData):
1717
# during runtime.
1818

1919
# Any variables returned here, will be passed as the secondary argument to your 'algorithm' function
20-
modelData.user_data['payload'] = "Loading has been completed."
20+
modelData['payload'] = "Loading has been completed."
2121
return modelData
2222

2323

examples/pytorch_image_classification/src/Algorithm.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def infer_image(image_url, n, globals):
5353

5454
def load(modelData):
5555

56-
modelData.user_data["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
57-
modelData.user_data["model"] = load_model(modelData.get_model("squeezenet"))
58-
modelData.user_data["labels"] = load_labels(modelData.get_model("labels"))
56+
modelData["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
57+
modelData["model"] = load_model(modelData.get_model("squeezenet"))
58+
modelData["labels"] = load_labels(modelData.get_model("labels"))
5959
return modelData
6060

6161

@@ -67,10 +67,10 @@ def apply(input, modelData):
6767
n = 3
6868
if "data" in input:
6969
if isinstance(input["data"], str):
70-
output = infer_image(input["data"], n, modelData.user_data)
70+
output = infer_image(input["data"], n, modelData)
7171
elif isinstance(input["data"], list):
7272
for row in input["data"]:
73-
row["predictions"] = infer_image(row["image_url"], n, modelData.user_data)
73+
row["predictions"] = infer_image(row["image_url"], n, modelData)
7474
output = input["data"]
7575
else:
7676
raise Exception("\"data\" must be a image url or a list of image urls (with labels)")

tests/adk_algorithms.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def apply_binary(input):
1616

1717
def apply_input_or_context(input, model_data=None):
1818
if model_data:
19-
return model_data.user_data
19+
return model_data.data()
2020
else:
2121
return "hello " + input
2222

@@ -30,7 +30,7 @@ def apply_successful_manifest_parsing(input, model_data):
3030

3131
# -- Loading functions --- #
3232
def loading_text(modelData):
33-
modelData.user_data['message'] = 'This message was loaded prior to runtime'
33+
modelData['message'] = 'This message was loaded prior to runtime'
3434
return modelData
3535

3636

@@ -39,14 +39,14 @@ def loading_exception(modelData):
3939

4040

4141
def loading_file_from_algorithmia(modelData):
42-
modelData.user_data['data_url'] = 'data://demo/collection/somefile.json'
43-
modelData.user_data['data'] = modelData.client.file(modelData.user_data['data_url']).getJson()
42+
modelData['data_url'] = 'data://demo/collection/somefile.json'
43+
modelData['data'] = modelData.client.file(modelData['data_url']).getJson()
4444
return modelData
4545

4646

4747
def loading_with_manifest(modelData):
48-
modelData.user_data["squeezenet"] = modelData.get_model("squeezenet")
49-
modelData.user_data['labels'] = modelData.get_model("labels")
48+
modelData["squeezenet"] = modelData.get_model("squeezenet")
49+
modelData['labels'] = modelData.get_model("labels")
5050
# optional model
51-
modelData.user_data['mobilenet'] = modelData.get_model("mobilenet")
51+
modelData['mobilenet'] = modelData.get_model("mobilenet")
5252
return modelData

0 commit comments

Comments
 (0)