Skip to content

Commit 8d6d221

Browse files
committed
param counter
1 parent bea65c7 commit 8d6d221

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

models/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
from torch import nn
3+
from thop import profile, clever_format
34

45

56
def count_parameters(model):
@@ -12,3 +13,8 @@ def init_weights(m):
1213
nn.init.xavier_normal_(m.weight)
1314
if m.bias is not None:
1415
nn.init.constant_(m.bias, 0.0)
16+
17+
18+
def params_flops(model, inputs):
19+
flops, params = profile(model, (inputs, ))
20+
clever_format([flops, params], "%.3f")

0 commit comments

Comments
 (0)