Skip to content

Commit 816c0c2

Browse files
committed
fix: remove hooks
1 parent acc128b commit 816c0c2

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

src/sasctl/utils/model_info.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,19 +223,29 @@ def _get_layer_info(model, X):
223223
224224
"""
225225
is_training = model.training
226+
227+
# Track the registered hooks so we can unregister them after running.
228+
# NOTE: if not removed, hooks can prevent model from being sucessfully pickled.
229+
hooks = []
230+
231+
# Track the layers and their input/output tensors for later reference.
226232
layers = []
227233

228234
def hook(module, input, output, *args):
229-
# layers[module] = (input, output)
230235
layers.append((module, input, output))
231236

232237
for module in model.modules():
233-
module.register_forward_hook(hook)
238+
handle = module.register_forward_hook(hook)
239+
hooks.append(handle)
234240

235241
model.eval()
236242
with torch.no_grad():
237243
model(X)
238244

245+
for handle in hooks:
246+
handle.remove()
247+
248+
model.train(is_training)
239249
return layers
240250

241251
@property

0 commit comments

Comments
 (0)