File tree Expand file tree Collapse file tree 1 file changed +12
-2
lines changed Expand file tree Collapse file tree 1 file changed +12
-2
lines changed Original file line number Diff line number Diff line change @@ -223,19 +223,29 @@ def _get_layer_info(model, X):
223
223
224
224
"""
225
225
is_training = model .training
226
+
227
+ # Track the registered hooks so we can unregister them after running.
228
+ # NOTE: if not removed, hooks can prevent model from being sucessfully pickled.
229
+ hooks = []
230
+
231
+ # Track the layers and their input/output tensors for later reference.
226
232
layers = []
227
233
228
234
def hook (module , input , output , * args ):
229
- # layers[module] = (input, output)
230
235
layers .append ((module , input , output ))
231
236
232
237
for module in model .modules ():
233
- module .register_forward_hook (hook )
238
+ handle = module .register_forward_hook (hook )
239
+ hooks .append (handle )
234
240
235
241
model .eval ()
236
242
with torch .no_grad ():
237
243
model (X )
238
244
245
+ for handle in hooks :
246
+ handle .remove ()
247
+
248
+ model .train (is_training )
239
249
return layers
240
250
241
251
@property
You can’t perform that action at this time.
0 commit comments