Skip to content

Commit 63323f1

Browse files
committed
updated readme
1 parent 2de51dc commit 63323f1

File tree

1 file changed

+13
-20
lines changed

1 file changed

+13
-20
lines changed

README.md

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ algorithm.init("Algorithmia")
122122
## [pytorch based image classification](examples/pytorch_image_classification)
123123
<!-- embedme examples/pytorch_image_classification/src/Algorithm.py -->
124124
```python
125-
from Algorithmia import client, ADK
125+
from Algorithmia import ADK
126+
import Algorithmia
126127
import torch
127128
from PIL import Image
128129
import json
@@ -136,14 +137,10 @@ def load_labels(label_path, client):
136137
return labels
137138

138139

139-
def load_model(name, model_paths, client):
140-
if name == "squeezenet":
141-
model = models.squeezenet1_1()
142-
models.densenet121()
143-
weights = torch.load(client.file(model_paths["squeezenet"]).getFile().name)
144-
else:
145-
model = models.alexnet()
146-
weights = torch.load(client.file(model_paths["alexnet"]).getFile().name)
140+
def load_model(model_paths, client):
141+
model = models.squeezenet1_1()
142+
local_file = client.file(model_paths["filepath"]).getFile().name
143+
weights = torch.load(local_file)
147144
model.load_state_dict(weights)
148145
return model.float().eval()
149146

@@ -177,17 +174,13 @@ def infer_image(image_url, n, globals):
177174
return result
178175

179176

180-
def load():
177+
def load(manifest):
178+
181179
globals = {}
182-
globals["MODEL_PATHS"] = {
183-
"squeezenet": "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth",
184-
"alexnet": "data://AlgorithmiaSE/image_cassification_demo/alexnet-owt-4df8aa71.pth",
185-
}
186-
globals["LABEL_PATHS"] = "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json"
187-
globals["CLIENT"] = client()
180+
client = Algorithmia.client()
188181
globals["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
189-
globals["model"] = load_model("squeezenet", globals["MODEL_PATHS"], globals["CLIENT"])
190-
globals["labels"] = load_labels(globals["LABEL_PATHS"], globals["CLIENT"])
182+
globals["model"] = load_model(manifest["squeezenet"], client)
183+
globals["labels"] = load_labels(manifest["label_file"], client)
191184
return globals
192185

193186

@@ -205,10 +198,10 @@ def apply(input, globals):
205198
row["predictions"] = infer_image(row["image_url"], n, globals)
206199
output = input["data"]
207200
else:
208-
raise Exception(""data" must be a image url or a list of image urls (with labels)")
201+
raise Exception("\"data\" must be a image url or a list of image urls (with labels)")
209202
return output
210203
else:
211-
raise Exception(""data" must be defined")
204+
raise Exception("\"data\" must be defined")
212205
else:
213206
raise Exception("input must be a json object")
214207

0 commit comments

Comments
 (0)