From 201197f46d661e3396261a4de828a860a7420b26 Mon Sep 17 00:00:00 2001 From: six-m Date: Thu, 9 Mar 2023 10:35:49 +0800 Subject: [PATCH] fix: remove all buffers registered besides the supported ops --- thop/profile.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/thop/profile.py b/thop/profile.py index 6b15d27..8f1aece 100644 --- a/thop/profile.py +++ b/thop/profile.py @@ -203,6 +203,10 @@ def add_hooks(m: nn.Module): ) types_collection.add(m_type) + def remove_buffers(m: nn.Module): + m._buffers.pop("total_ops") + m._buffers.pop("total_params") + prev_training_status = model.training model.eval() @@ -239,8 +243,7 @@ def dfs_count(module: nn.Module, prefix="\t") -> (int, int): for m, (op_handler, params_handler) in handler_collection.items(): op_handler.remove() params_handler.remove() - m._buffers.pop("total_ops") - m._buffers.pop("total_params") + model.apply(remove_buffers) if ret_layer_info: return total_ops, total_params, ret_dict