Skip to content

Commit 2a22e99

Browse files
authored
Merge pull request #5 from algorithmiaio/advanced-example-update
moved all global functionality into load function dictionary
2 parents 128da0a + 63323f1 commit 2a22e99

File tree

3 files changed

+79
-78
lines changed

3 files changed

+79
-78
lines changed

README.md

Lines changed: 33 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def load():
104104
# The return object from this function can be passed directly as input to your apply function.
105105
# A great example would be any model files that need to be available to this algorithm
106106
# during runtime.
107+
107108
# Any variables returned here, will be passed as the secondary argument to your 'algorithm' function
108109
globals = {}
109110
globals['payload'] = "Loading has been completed."
@@ -121,53 +122,41 @@ algorithm.init("Algorithmia")
121122
## [pytorch based image classification](examples/pytorch_image_classification)
122123
<!-- embedme examples/pytorch_image_classification/src/Algorithm.py -->
123124
```python
124-
from Algorithmia import client, ADK
125+
from Algorithmia import ADK
126+
import Algorithmia
125127
import torch
126128
from PIL import Image
127129
import json
128130
from torchvision import models, transforms
129131

130-
CLIENT = client()
131-
SMID_ALGO = "algo://util/SmartImageDownloader/0.2.x"
132-
LABEL_PATH = "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json"
133-
MODEL_PATHS = {
134-
"squeezenet": 'data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth',
135-
'alexnet': 'data://AlgorithmiaSE/image_cassification_demo/alexnet-owt-4df8aa71.pth',
136-
}
137-
138-
139-
def load_labels():
140-
local_path = CLIENT.file(LABEL_PATH).getFile().name
132+
def load_labels(label_path, client):
133+
local_path = client.file(label_path).getFile().name
141134
with open(local_path) as f:
142135
labels = json.load(f)
143136
labels = [labels[str(k)][1] for k in range(len(labels))]
144137
return labels
145138

146139

147-
def load_model(name):
148-
if name == "squeezenet":
149-
model = models.squeezenet1_1()
150-
models.densenet121()
151-
weights = torch.load(CLIENT.file(MODEL_PATHS['squeezenet']).getFile().name)
152-
else:
153-
model = models.alexnet()
154-
weights = torch.load(CLIENT.file(MODEL_PATHS['alexnet']).getFile().name)
140+
def load_model(model_paths, client):
141+
model = models.squeezenet1_1()
142+
local_file = client.file(model_paths["filepath"]).getFile().name
143+
weights = torch.load(local_file)
155144
model.load_state_dict(weights)
156145
return model.float().eval()
157146

158147

159-
def get_image(image_url):
160-
input = {"image": image_url, "resize": {'width': 224, 'height': 224}}
161-
result = CLIENT.algo(SMID_ALGO).pipe(input).result["savePath"][0]
162-
local_path = CLIENT.file(result).getFile().name
148+
def get_image(image_url, smid_algo, client):
149+
input = {"image": image_url, "resize": {"width": 224, "height": 224}}
150+
result = client.algo(smid_algo).pipe(input).result["savePath"][0]
151+
local_path = client.file(result).getFile().name
163152
img_data = Image.open(local_path)
164153
return img_data
165154

166155

167156
def infer_image(image_url, n, globals):
168-
model = globals['model']
169-
labels = globals['labels']
170-
image_data = get_image(image_url)
157+
model = globals["model"]
158+
labels = globals["labels"]
159+
image_data = get_image(image_url, globals["SMID_ALGO"], globals["CLIENT"])
171160
transformed = transforms.Compose([
172161
transforms.ToTensor(),
173162
transforms.Normalize(mean=[0.485, 0.456, 0.406],
@@ -185,31 +174,36 @@ def infer_image(image_url, n, globals):
185174
return result
186175

187176

188-
def load():
189-
globals = {'model': load_model("squeezenet"), 'labels': load_labels()}
177+
def load(manifest):
178+
179+
globals = {}
180+
client = Algorithmia.client()
181+
globals["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
182+
globals["model"] = load_model(manifest["squeezenet"], client)
183+
globals["labels"] = load_labels(manifest["label_file"], client)
190184
return globals
191185

192186

193187
def apply(input, globals):
194188
if isinstance(input, dict):
195189
if "n" in input:
196-
n = input['n']
190+
n = input["n"]
197191
else:
198192
n = 3
199193
if "data" in input:
200-
if isinstance(input['data'], str):
201-
output = infer_image(input['data'], n, globals)
202-
elif isinstance(input['data'], list):
203-
for row in input['data']:
204-
row['predictions'] = infer_image(row['image_url'], n, globals)
205-
output = input['data']
194+
if isinstance(input["data"], str):
195+
output = infer_image(input["data"], n, globals)
196+
elif isinstance(input["data"], list):
197+
for row in input["data"]:
198+
row["predictions"] = infer_image(row["image_url"], n, globals)
199+
output = input["data"]
206200
else:
207-
raise Exception("'data' must be a image url or a list of image urls (with labels)")
201+
raise Exception("\"data\" must be a image url or a list of image urls (with labels)")
208202
return output
209203
else:
210-
raise Exception("'data' must be defined")
204+
raise Exception("\"data\" must be defined")
211205
else:
212-
raise Exception('input must be a json object')
206+
raise Exception("input must be a json object")
213207

214208

215209
algorithm = ADK(apply_func=apply, load_func=load)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{
2+
"label_file": {
3+
"filepath": "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json",
4+
"md5_hash": "c2c37ea517e94d9795004a39431a14cb",
5+
"origin_ref": "this file came from imagenet.org",
6+
"uploaded_utc": "2021-05-03-11:05"
7+
},
8+
"squeezenet": {
9+
"filepath": "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth",
10+
"md5_hash": "46a44d32d2c5c07f7f66324bef4c7266",
11+
"origin_ref": "From https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth",
12+
"uploaded_utc": "2021-05-03-11:05"
13+
}
14+
}

examples/pytorch_image_classification/src/Algorithm.py

Lines changed: 32 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,38 @@
1-
from Algorithmia import client, ADK
1+
from Algorithmia import ADK
2+
import Algorithmia
23
import torch
34
from PIL import Image
45
import json
56
from torchvision import models, transforms
67

7-
CLIENT = client()
8-
SMID_ALGO = "algo://util/SmartImageDownloader/0.2.x"
9-
LABEL_PATH = "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json"
10-
MODEL_PATHS = {
11-
"squeezenet": 'data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth',
12-
'alexnet': 'data://AlgorithmiaSE/image_cassification_demo/alexnet-owt-4df8aa71.pth',
13-
}
14-
15-
16-
def load_labels():
17-
local_path = CLIENT.file(LABEL_PATH).getFile().name
8+
def load_labels(label_path, client):
9+
local_path = client.file(label_path).getFile().name
1810
with open(local_path) as f:
1911
labels = json.load(f)
2012
labels = [labels[str(k)][1] for k in range(len(labels))]
2113
return labels
2214

2315

24-
def load_model(name):
25-
if name == "squeezenet":
26-
model = models.squeezenet1_1()
27-
models.densenet121()
28-
weights = torch.load(CLIENT.file(MODEL_PATHS['squeezenet']).getFile().name)
29-
else:
30-
model = models.alexnet()
31-
weights = torch.load(CLIENT.file(MODEL_PATHS['alexnet']).getFile().name)
16+
def load_model(model_paths, client):
17+
model = models.squeezenet1_1()
18+
local_file = client.file(model_paths["filepath"]).getFile().name
19+
weights = torch.load(local_file)
3220
model.load_state_dict(weights)
3321
return model.float().eval()
3422

3523

36-
def get_image(image_url):
37-
input = {"image": image_url, "resize": {'width': 224, 'height': 224}}
38-
result = CLIENT.algo(SMID_ALGO).pipe(input).result["savePath"][0]
39-
local_path = CLIENT.file(result).getFile().name
24+
def get_image(image_url, smid_algo, client):
25+
input = {"image": image_url, "resize": {"width": 224, "height": 224}}
26+
result = client.algo(smid_algo).pipe(input).result["savePath"][0]
27+
local_path = client.file(result).getFile().name
4028
img_data = Image.open(local_path)
4129
return img_data
4230

4331

4432
def infer_image(image_url, n, globals):
45-
model = globals['model']
46-
labels = globals['labels']
47-
image_data = get_image(image_url)
33+
model = globals["model"]
34+
labels = globals["labels"]
35+
image_data = get_image(image_url, globals["SMID_ALGO"], globals["CLIENT"])
4836
transformed = transforms.Compose([
4937
transforms.ToTensor(),
5038
transforms.Normalize(mean=[0.485, 0.456, 0.406],
@@ -62,31 +50,36 @@ def infer_image(image_url, n, globals):
6250
return result
6351

6452

65-
def load():
66-
globals = {'model': load_model("squeezenet"), 'labels': load_labels()}
53+
def load(manifest):
54+
55+
globals = {}
56+
client = Algorithmia.client()
57+
globals["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
58+
globals["model"] = load_model(manifest["squeezenet"], client)
59+
globals["labels"] = load_labels(manifest["label_file"], client)
6760
return globals
6861

6962

7063
def apply(input, globals):
7164
if isinstance(input, dict):
7265
if "n" in input:
73-
n = input['n']
66+
n = input["n"]
7467
else:
7568
n = 3
7669
if "data" in input:
77-
if isinstance(input['data'], str):
78-
output = infer_image(input['data'], n, globals)
79-
elif isinstance(input['data'], list):
80-
for row in input['data']:
81-
row['predictions'] = infer_image(row['image_url'], n, globals)
82-
output = input['data']
70+
if isinstance(input["data"], str):
71+
output = infer_image(input["data"], n, globals)
72+
elif isinstance(input["data"], list):
73+
for row in input["data"]:
74+
row["predictions"] = infer_image(row["image_url"], n, globals)
75+
output = input["data"]
8376
else:
84-
raise Exception("'data' must be a image url or a list of image urls (with labels)")
77+
raise Exception("\"data\" must be a image url or a list of image urls (with labels)")
8578
return output
8679
else:
87-
raise Exception("'data' must be defined")
80+
raise Exception("\"data\" must be defined")
8881
else:
89-
raise Exception('input must be a json object')
82+
raise Exception("input must be a json object")
9083

9184

9285
algorithm = ADK(apply_func=apply, load_func=load)

0 commit comments

Comments
 (0)