@@ -104,6 +104,7 @@ def load():
104
104
# The return object from this function can be passed directly as input to your apply function.
105
105
# A great example would be any model files that need to be available to this algorithm
106
106
# during runtime.
107
+
107
108
# Any variables returned here, will be passed as the secondary argument to your 'algorithm' function
108
109
globals = {}
109
110
globals [' payload' ] = " Loading has been completed."
@@ -121,53 +122,41 @@ algorithm.init("Algorithmia")
121
122
## [ pytorch based image classification] ( examples/pytorch_image_classification )
122
123
<!-- embedme examples/pytorch_image_classification/src/Algorithm.py -->
123
124
``` python
124
- from Algorithmia import client, ADK
125
+ from Algorithmia import ADK
126
+ import Algorithmia
125
127
import torch
126
128
from PIL import Image
127
129
import json
128
130
from torchvision import models, transforms
129
131
130
- CLIENT = client()
131
- SMID_ALGO = " algo://util/SmartImageDownloader/0.2.x"
132
- LABEL_PATH = " data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json"
133
- MODEL_PATHS = {
134
- " squeezenet" : ' data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth' ,
135
- ' alexnet' : ' data://AlgorithmiaSE/image_cassification_demo/alexnet-owt-4df8aa71.pth' ,
136
- }
137
-
138
-
139
- def load_labels ():
140
- local_path = CLIENT .file(LABEL_PATH ).getFile().name
132
+ def load_labels (label_path , client ):
133
+ local_path = client.file(label_path).getFile().name
141
134
with open (local_path) as f:
142
135
labels = json.load(f)
143
136
labels = [labels[str (k)][1 ] for k in range (len (labels))]
144
137
return labels
145
138
146
139
147
- def load_model (name ):
148
- if name == " squeezenet" :
149
- model = models.squeezenet1_1()
150
- models.densenet121()
151
- weights = torch.load(CLIENT .file(MODEL_PATHS [' squeezenet' ]).getFile().name)
152
- else :
153
- model = models.alexnet()
154
- weights = torch.load(CLIENT .file(MODEL_PATHS [' alexnet' ]).getFile().name)
140
+ def load_model (model_paths , client ):
141
+ model = models.squeezenet1_1()
142
+ local_file = client.file(model_paths[" filepath" ]).getFile().name
143
+ weights = torch.load(local_file)
155
144
model.load_state_dict(weights)
156
145
return model.float().eval()
157
146
158
147
159
- def get_image (image_url ):
160
- input = {" image" : image_url, " resize" : {' width' : 224 , ' height' : 224 }}
161
- result = CLIENT .algo(SMID_ALGO ).pipe(input ).result[" savePath" ][0 ]
162
- local_path = CLIENT .file(result).getFile().name
148
+ def get_image (image_url , smid_algo , client ):
149
+ input = {" image" : image_url, " resize" : {" width" : 224 , " height" : 224 }}
150
+ result = client .algo(smid_algo ).pipe(input ).result[" savePath" ][0 ]
151
+ local_path = client .file(result).getFile().name
163
152
img_data = Image.open(local_path)
164
153
return img_data
165
154
166
155
167
156
def infer_image (image_url , n , globals ):
168
- model = globals [' model' ]
169
- labels = globals [' labels' ]
170
- image_data = get_image(image_url)
157
+ model = globals [" model" ]
158
+ labels = globals [" labels" ]
159
+ image_data = get_image(image_url, globals [ " SMID_ALGO " ], globals [ " CLIENT " ] )
171
160
transformed = transforms.Compose([
172
161
transforms.ToTensor(),
173
162
transforms.Normalize(mean = [0.485 , 0.456 , 0.406 ],
@@ -185,31 +174,36 @@ def infer_image(image_url, n, globals):
185
174
return result
186
175
187
176
188
- def load ():
189
- globals = {' model' : load_model(" squeezenet" ), ' labels' : load_labels()}
177
+ def load (manifest ):
178
+
179
+ globals = {}
180
+ client = Algorithmia.client()
181
+ globals [" SMID_ALGO" ] = " algo://util/SmartImageDownloader/0.2.x"
182
+ globals [" model" ] = load_model(manifest[" squeezenet" ], client)
183
+ globals [" labels" ] = load_labels(manifest[" label_file" ], client)
190
184
return globals
191
185
192
186
193
187
def apply (input , globals ):
194
188
if isinstance (input , dict ):
195
189
if " n" in input :
196
- n = input [' n ' ]
190
+ n = input [" n " ]
197
191
else :
198
192
n = 3
199
193
if " data" in input :
200
- if isinstance (input [' data' ], str ):
201
- output = infer_image(input [' data' ], n, globals )
202
- elif isinstance (input [' data' ], list ):
203
- for row in input [' data' ]:
204
- row[' predictions' ] = infer_image(row[' image_url' ], n, globals )
205
- output = input [' data' ]
194
+ if isinstance (input [" data" ], str ):
195
+ output = infer_image(input [" data" ], n, globals )
196
+ elif isinstance (input [" data" ], list ):
197
+ for row in input [" data" ]:
198
+ row[" predictions" ] = infer_image(row[" image_url" ], n, globals )
199
+ output = input [" data" ]
206
200
else :
207
- raise Exception (" ' data' must be a image url or a list of image urls (with labels)" )
201
+ raise Exception (" \" data\" must be a image url or a list of image urls (with labels)" )
208
202
return output
209
203
else :
210
- raise Exception (" ' data' must be defined" )
204
+ raise Exception (" \" data\" must be defined" )
211
205
else :
212
- raise Exception (' input must be a json object' )
206
+ raise Exception (" input must be a json object" )
213
207
214
208
215
209
algorithm = ADK(apply_func = apply, load_func = load)
0 commit comments