@@ -122,7 +122,8 @@ algorithm.init("Algorithmia")
122
122
## [ pytorch based image classification] ( examples/pytorch_image_classification )
123
123
<!-- embedme examples/pytorch_image_classification/src/Algorithm.py -->
124
124
``` python
125
- from Algorithmia import client, ADK
125
+ from Algorithmia import ADK
126
+ import Algorithmia
126
127
import torch
127
128
from PIL import Image
128
129
import json
@@ -136,14 +137,10 @@ def load_labels(label_path, client):
136
137
return labels
137
138
138
139
139
- def load_model (name , model_paths , client ):
140
- if name == " squeezenet" :
141
- model = models.squeezenet1_1()
142
- models.densenet121()
143
- weights = torch.load(client.file(model_paths[" squeezenet" ]).getFile().name)
144
- else :
145
- model = models.alexnet()
146
- weights = torch.load(client.file(model_paths[" alexnet" ]).getFile().name)
140
+ def load_model (model_paths , client ):
141
+ model = models.squeezenet1_1()
142
+ local_file = client.file(model_paths[" filepath" ]).getFile().name
143
+ weights = torch.load(local_file)
147
144
model.load_state_dict(weights)
148
145
return model.float().eval()
149
146
@@ -177,17 +174,13 @@ def infer_image(image_url, n, globals):
177
174
return result
178
175
179
176
180
- def load ():
177
+ def load (manifest ):
178
+
181
179
globals = {}
182
- globals [" MODEL_PATHS" ] = {
183
- " squeezenet" : " data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth" ,
184
- " alexnet" : " data://AlgorithmiaSE/image_cassification_demo/alexnet-owt-4df8aa71.pth" ,
185
- }
186
- globals [" LABEL_PATHS" ] = " data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json"
187
- globals [" CLIENT" ] = client()
180
+ client = Algorithmia.client()
188
181
globals [" SMID_ALGO" ] = " algo://util/SmartImageDownloader/0.2.x"
189
- globals [" model" ] = load_model(" squeezenet" , globals [ " MODEL_PATHS " ], globals [ " CLIENT " ] )
190
- globals [" labels" ] = load_labels(globals [ " LABEL_PATHS " ], globals [ " CLIENT " ] )
182
+ globals [" model" ] = load_model(manifest[ " squeezenet" ], client )
183
+ globals [" labels" ] = load_labels(manifest[ " label_file " ], client )
191
184
return globals
192
185
193
186
@@ -205,10 +198,10 @@ def apply(input, globals):
205
198
row[" predictions" ] = infer_image(row[" image_url" ], n, globals )
206
199
output = input [" data" ]
207
200
else :
208
- raise Exception (" " data" must be a image url or a list of image urls (with labels)" )
201
+ raise Exception (" \ " data\ " must be a image url or a list of image urls (with labels)" )
209
202
return output
210
203
else :
211
- raise Exception (" " data" must be defined" )
204
+ raise Exception (" \ " data\ " must be defined" )
212
205
else :
213
206
raise Exception (" input must be a json object" )
214
207
0 commit comments