From 5cee692c6093a96d65bc280447f12ae2993b1463 Mon Sep 17 00:00:00 2001 From: chenghaoDong666 <1785246872@qq.com> Date: Mon, 3 May 2021 00:32:59 +0800 Subject: [PATCH] fix #119 fix #114 fix #112 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 更改了代码逻辑,现在在计算ops和params时,每个模块会包含两部分,子模块的部分和除了子模块之外的部分;除此之外还对README.md有略微的修改 --- README.md | 5 +++-- thop/profile.py | 13 +++++++++++-- thop/vision/basic_hooks.py | 5 +++-- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 8a2da2e..5700a11 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ OR ## How to use * Basic usage ```python + import torch from torchvision.models import resnet50 from thop import profile model = resnet50() @@ -24,7 +25,7 @@ OR # your definition def count_your_model(model, x, y): # your rule here - + # Note that your rule only calculate the ops and params except its submodule's ops and params input = torch.randn(1, 3, 224, 224) macs, params = profile(model, inputs=(input, ), custom_ops={YourModule: count_your_model}) @@ -92,4 +93,4 @@ inception_v3 | 27.16 | 5.75 -

+

\ No newline at end of file diff --git a/thop/profile.py b/thop/profile.py index 4b98364..e1efa34 100644 --- a/thop/profile.py +++ b/thop/profile.py @@ -191,7 +191,16 @@ def add_hooks(m: nn.Module): model(*inputs) def dfs_count(module: nn.Module, prefix="\t") -> (int, int): - total_ops, total_params = 0, 0 + """ + calculate the ops and params through dfs + For each module's ops and params,it contains two part: + 1) the ops and params of its submodule + 2) the ops and params except 1) + :param module: the module + :param prefix: the prefix + :return: total_ops, total_params + """ + total_ops, total_params = module.total_ops, module.total_params for m in module.children(): # if not hasattr(m, "total_ops") and not hasattr(m, "total_params"): # and len(list(m.children())) > 0: # m_ops, m_params = dfs_count(m, prefix=prefix + "\t") @@ -204,7 +213,7 @@ def dfs_count(module: nn.Module, prefix="\t") -> (int, int): total_ops += m_ops total_params += m_params # print(prefix, module._get_name(), (total_ops.item(), total_params.item())) - return total_ops, total_params + return total_ops.item(), total_params.item() total_ops, total_params = dfs_count(model) diff --git a/thop/vision/basic_hooks.py b/thop/vision/basic_hooks.py index e3d7d7d..be4272a 100644 --- a/thop/vision/basic_hooks.py +++ b/thop/vision/basic_hooks.py @@ -10,9 +10,10 @@ def count_parameters(m, x, y): total_params = 0 - for p in m.parameters(): + for p in m.parameters(recurse = False): total_params += torch.DoubleTensor([p.numel()]) - m.total_params[0] = total_params + if type(total_params) != int: + m.total_params += total_params def zero_ops(m, x, y):