@@ -129,18 +129,19 @@ from PIL import Image
129
129
import json
130
130
from torchvision import models, transforms
131
131
132
- def load_labels (label_path , client ):
133
- local_path = client.file(label_path).getFile().name
134
- with open (local_path) as f:
132
+
133
+ client = Algorithmia.client()
134
+
135
+ def load_labels (label_path ):
136
+ with open (label_path) as f:
135
137
labels = json.load(f)
136
138
labels = [labels[str (k)][1 ] for k in range (len (labels))]
137
139
return labels
138
140
139
141
140
- def load_model (model_paths , client ):
142
+ def load_model (model_path ):
141
143
model = models.squeezenet1_1()
142
- local_file = client.file(model_paths[" filepath" ]).getFile().name
143
- weights = torch.load(local_file)
144
+ weights = torch.load(model_path)
144
145
model.load_state_dict(weights)
145
146
return model.float().eval()
146
147
@@ -176,26 +177,25 @@ def infer_image(image_url, n, globals):
176
177
177
178
def load (manifest ):
178
179
179
- globals = {}
180
- client = Algorithmia.client()
181
- globals [" SMID_ALGO" ] = " algo://util/SmartImageDownloader/0.2.x"
182
- globals [" model" ] = load_model(manifest[" squeezenet" ], client)
183
- globals [" labels" ] = load_labels(manifest[" label_file" ], client)
184
- return globals
180
+ state = {}
181
+ state[" SMID_ALGO" ] = " algo://util/SmartImageDownloader/0.2.x"
182
+ state[" model" ] = load_model(manifest.get_model(" squeezenet" ))
183
+ state[" labels" ] = load_labels(manifest.get_model(" labels" ))
184
+ return state
185
185
186
186
187
- def apply (input , globals ):
187
+ def apply (input , state ):
188
188
if isinstance (input , dict ):
189
189
if " n" in input :
190
190
n = input [" n" ]
191
191
else :
192
192
n = 3
193
193
if " data" in input :
194
194
if isinstance (input [" data" ], str ):
195
- output = infer_image(input [" data" ], n, globals )
195
+ output = infer_image(input [" data" ], n, state )
196
196
elif isinstance (input [" data" ], list ):
197
197
for row in input [" data" ]:
198
- row[" predictions" ] = infer_image(row[" image_url" ], n, globals )
198
+ row[" predictions" ] = infer_image(row[" image_url" ], n, state )
199
199
output = input [" data" ]
200
200
else :
201
201
raise Exception (" \" data\" must be a image url or a list of image urls (with labels)" )
@@ -206,7 +206,7 @@ def apply(input, globals):
206
206
raise Exception (" input must be a json object" )
207
207
208
208
209
- algorithm = ADK(apply_func = apply, load_func = load)
209
+ algorithm = ADK(apply_func = apply, load_func = load, client = client )
210
210
algorithm.init({" data" : " https://i.imgur.com/bXdORXl.jpeg" })
211
211
212
212
```
0 commit comments