diff --git a/benchmark/evaluate_transformer.py b/benchmark/evaluate_transformer.py
new file mode 100644
index 0000000..4a83f27
--- /dev/null
+++ b/benchmark/evaluate_transformer.py
@@ -0,0 +1,26 @@
+
+import torch.nn as nn
+from thop import profile
+import torch
+
+src = torch.rand((1, 1, 10))  # (S, N, feature): seq len 1, batch 1, 10 input features
+
+
+class ModelTransformer(nn.Module):
+    def __init__(self):
+        super(ModelTransformer, self).__init__()
+        self.linear1 = nn.Linear(10, 512)
+        self.linear2 = nn.Linear(10, 512)
+        self.transform = nn.Transformer(
+            d_model=512, nhead=8, num_encoder_layers=6)
+
+    def forward(self, input):
+        input1 = self.linear1(input)
+        input2 = self.linear2(input)
+        output = self.transform(input1, input2)
+        return output
+
+
+model = ModelTransformer()
+macs, params = profile(model, inputs=(src, ))
+print(macs, params)
diff --git a/test.py b/test.py
index c910ccd..56e42a9 100644
--- a/test.py
+++ b/test.py
@@ -4,5 +4,7 @@
 
 m = torch.nn.Conv2d(128, 128, 1)
 x = torch.randn(1, 128, 16, 16)
+
 flops = thop.profile(m, inputs=(x,), verbose=True)
 print(flops)
+
diff --git a/thop/profile.py b/thop/profile.py
index cc1e8c4..6944781 100644
--- a/thop/profile.py
+++ b/thop/profile.py
@@ -67,7 +67,7 @@ def prYellow(skk): print("\033[93m{}\033[00m".format(skk))
     nn.RNN: count_rnn,
     nn.GRU: count_gru,
     nn.LSTM: count_lstm,
-
+    nn.Transformer: count_transformer,
 
     nn.Sequential: zero_ops,
 }
diff --git a/thop/rnn_hooks.py b/thop/rnn_hooks.py
index c00fd47..6238bf6 100644
--- a/thop/rnn_hooks.py
+++ b/thop/rnn_hooks.py
@@ -196,3 +196,88 @@ def count_lstm(m: nn.LSTM, x, y):
     total_ops *= batch_size
 
     m.total_ops += torch.DoubleTensor([int(total_ops)])
+
+
+def count_transformer(m: nn.Transformer, x, y):
+    """Estimate MACs for a full nn.Transformer (encoder + decoder stacks)."""
+    src, tgt = x
+    if m.batch_first:
+        # src: (N, S, E), tgt: (N, T, E)
+        num_steps = src.shape[0]   # batch size
+        sequence = src.shape[1]    # source length S
+        target = tgt.shape[1]      # target length T
+        embedding = src.shape[2]   # model dimension E
+    else:
+        # src: (S, N, E), tgt: (T, N, E)
+        sequence = src.shape[0]
+        target = tgt.shape[0]
+        num_steps = src.shape[1]
+        embedding = src.shape[2]
+    num_head = m.nhead
+    encoder_layers = m.encoder.num_layers
+    decoder_layers = m.decoder.num_layers
+    # dim_feedforward (default: 2048)
+    dim_feedforward = m.encoder.layers[0].linear1.out_features
+
+    def multihead_attention(mode, num_head, num_steps, target, sequence, embedding):
+        # mode 0: encoder self-attention, 1: decoder self-attention,
+        # 2: decoder cross-attention over the encoder memory
+        if mode == 0:
+            # q/k/v projections over the source sequence
+            total_multi = 3 * sequence * embedding ** 2
+            # softmax(Q*K_T/sqrt(dk))*V, per head
+            total_multi += (2 * sequence ** 2 * (embedding / num_head)
+                            + sequence ** 2 + sequence * (3 * sequence - 1) + 1) * num_head
+            # output projection
+            total_multi += sequence * embedding ** 2
+            # layer norm
+            total_multi += 2 * sequence * embedding
+        elif mode == 1:
+            # q/k/v projections over the target sequence
+            total_multi = 3 * target * embedding ** 2
+            # softmax(Q*K_T/sqrt(dk))*V, per head
+            total_multi += (2 * target ** 2 * (embedding / num_head)
+                            + target ** 2 + target * (3 * target - 1) + 1) * num_head
+            total_multi += target * embedding ** 2
+            total_multi += 2 * target * embedding
+        elif mode == 2:
+            # k/v projections from the memory, q projection from the target
+            total_multi = embedding ** 2 * (2 * sequence + target)
+            # softmax(Q*K_T/sqrt(dk))*V, per head
+            total_multi += (2 * target * sequence * (embedding / num_head)
+                            + target * sequence + target * (3 * sequence - 1) + 1) * num_head
+            total_multi += target * embedding ** 2
+            total_multi += 2 * target * embedding
+        # scale by batch size
+        total_multi *= num_steps
+        return total_multi
+
+    def transformer_encoder_layer(num_head, num_steps, target, sequence, embedding):
+        total_en = 0
+        total_en += multihead_attention(0, num_head,
+                                        num_steps, target, sequence, embedding)
+        # feed-forward network: linear1 (E -> dim_feedforward) and linear2
+        total_en += num_steps * sequence * dim_feedforward * embedding
+        total_en += num_steps * sequence * embedding * dim_feedforward
+        # second layer norm
+        total_en += 2 * num_steps * embedding * sequence
+        return total_en
+
+    def transformer_decoder_layer(num_head, num_steps, target, sequence, embedding):
+        total_de = 0
+        total_de += multihead_attention(1, num_head,
+                                        num_steps, target, sequence, embedding)
+        total_de += multihead_attention(2, num_head,
+                                        num_steps, target, sequence, embedding)
+        # feed-forward network: linear1 and linear2
+        total_de += num_steps * target * dim_feedforward * embedding
+        total_de += num_steps * target * embedding * dim_feedforward
+        # third layer norm
+        total_de += 2 * num_steps * embedding * target
+        return total_de
+
+    total_ops = encoder_layers * transformer_encoder_layer(
+        num_head, num_steps, target, sequence, embedding)
+    total_ops += decoder_layers * transformer_decoder_layer(
+        num_head, num_steps, target, sequence, embedding)
+    m.total_ops += torch.DoubleTensor([int(total_ops)])
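As a quick way to exercise the new hook outside of benchmark/evaluate_transformer.py, the sketch below profiles a bare nn.Transformer directly. It is only a sketch: it assumes this patched thop is importable, and the model size and sequence lengths are illustrative.

    import torch
    import torch.nn as nn
    from thop import profile

    # Default seq-first layout: src is (S, N, E), tgt is (T, N, E).
    model = nn.Transformer(d_model=512, nhead=8,
                           num_encoder_layers=6, num_decoder_layers=6)
    src = torch.rand(10, 1, 512)
    tgt = torch.rand(20, 1, 512)

    # profile() dispatches nn.Transformer to count_transformer via register_hooks.
    macs, params = profile(model, inputs=(src, tgt))
    print(macs, params)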