Add transformer counter to profile #149
base: master
Changes from 17 commits
@@ -0,0 +1,26 @@
import torch
import torch.nn as nn
from thop import profile

src = torch.rand((1, 1, 10))  # (S, N, E): sequence length 1, batch size 1, 10 input features


class Model_transformer(nn.Module):
    def __init__(self):
        super(Model_transformer, self).__init__()
        # project the 10-dim input up to the transformer's d_model of 512
        self.linear1 = nn.Linear(10, 512)
        self.linear2 = nn.Linear(10, 512)
        self.transform = nn.Transformer(
            d_model=512, nhead=8, num_encoder_layers=6)

    def forward(self, input):
        input1 = self.linear1(input)  # encoder input (src)
        input2 = self.linear2(input)  # decoder input (tgt)
        output = self.transform(input1, input2)
        return output


model = Model_transformer()
macs, params = profile(model, inputs=(src, ))
print(macs, params)
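The raw counts print as plain floats. A minimal follow-up sketch, using thop's clever_format helper from the released package to render them as human-readable strings (continuing the macs and params variables above):

from thop import clever_format

# Turn the raw counts into readable strings such as "1.234G" / "44.735M".
macs_readable, params_readable = clever_format([macs, params], "%.3f")
print(macs_readable, params_readable)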
@@ -67,7 +67,7 @@ def prYellow(skk): fprint("\033[93m{}\033[00m".format(skk))
    nn.RNN: count_rnn,
    nn.GRU: count_gru,
    nn.LSTM: count_lstm,

    nn.Transformer: count_Transformer,

    nn.Sequential: zero_ops,
}

Review comment on the count_Transformer entry: function name should be lower case.
Reply: fixed
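The registry change above makes the new hook the default handler for nn.Transformer. On a thop release that does not include this change yet, the same hook can be attached per call through profile's custom_ops argument. A minimal sketch, assuming count_Transformer can be imported from the module the diff below adds it to (the import path here is a placeholder):

import torch
import torch.nn as nn
from thop import profile

# Placeholder import path; point it at wherever count_Transformer is defined.
from thop.rnn_hooks import count_Transformer


class TransformerWrapper(nn.Module):
    # Wrap nn.Transformer so the hook fires on a child module, mirroring the test script above.
    def __init__(self):
        super().__init__()
        self.transform = nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6)

    def forward(self, src, tgt):
        return self.transform(src, tgt)


src = torch.rand((10, 1, 512))  # (S, N, E)
tgt = torch.rand((20, 1, 512))  # (T, N, E)

model = TransformerWrapper()
macs, params = profile(
    model,
    inputs=(src, tgt),
    custom_ops={nn.Transformer: count_Transformer},  # attach the counter without editing thop itself
)
print(macs, params)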
@@ -196,3 +196,84 @@ def count_lstm(m: nn.LSTM, x, y):
    total_ops *= batch_size

    m.total_ops += torch.DoubleTensor([int(total_ops)])


def count_Transformer(m: nn.Transformer, x, y):
    total_ops = 0
    src, tgt = x
    if m.batch_first:
        num_steps = src.shape[0]
        target = tgt.shape[1]
        sequence = src.shape[1]
        embedding = src.shape[2]
    else:
        target = tgt.shape[0]
        sequence = src.shape[0]
        num_steps = src.shape[1]
        embedding = src.shape[2]
    num_head = m.nhead
    encoder_layers = m.encoder.num_layers
    decoder_layers = m.decoder.num_layers
    # dim_feedforward (default 2048)
    forward = m.encoder.layers[0].linear1.out_features
    total_ops = 0

    def MultiheadAttention(bool1, num_head, num_steps, target, sequence, embedding):
        if bool1 == 0:
            # encoder self-attention: linear_q, linear_k, linear_v, all (N, S, E)
            total_multi = 3 * sequence * embedding ** 2
            # self_attn: softmax(Q*K_T/sqrt(dk))*V
            total_multi += (sequence ** 4 * (embedding / num_head) ** 2 +
                            sequence ** 2 + sequence * (3 * sequence - 1) + 1) * num_head
            # output linear projection
            total_multi += sequence * embedding ** 2
            # layernorm
            total_multi += 2 * sequence * embedding
        elif bool1 == 1:
            # decoder self-attention: linear_q, linear_k, linear_v
            total_multi = 3 * target * embedding ** 2
            # self_attn: softmax(Q*K_T/sqrt(dk))*V
            total_multi += (target ** 4 * (embedding / num_head) ** 2 +
                            target ** 2 + target * (3 * target - 1) + 1) * num_head
            total_multi += target * embedding ** 2
            total_multi += 2 * target * embedding
        elif bool1 == 2:
            # decoder cross-attention: linear_q, linear_k, linear_v
            total_multi = embedding ** 2 * (2 * sequence + target)
            # cross_attn: softmax(Q*K_T/sqrt(dk))*V
            total_multi += (target ** 2 * sequence ** 2 * (embedding / num_head) ** 2 +
                            target * sequence + target * (3 * sequence - 1) + 1) * num_head
            total_multi += target * embedding ** 2
            total_multi += 2 * target * embedding
        # scale by batch size (num_steps)
        total_multi *= num_steps
        return total_multi

    def TransformerEncoderLayer(num_head, num_steps, target, sequence, embedding):
        total_en = 0
        total_en += MultiheadAttention(0, num_head,
                                       num_steps, target, sequence, embedding)
        # feed-forward block (linear1 + linear2)
        total_en += num_steps * sequence * forward * embedding
        total_en += num_steps * sequence * embedding * forward
        # norm1
        total_en += 2 * num_steps * embedding * sequence
        return total_en

    def TransformerDecoderLayer(num_head, num_steps, target, sequence, embedding):
        total_de = 0
        total_de += MultiheadAttention(1, num_head,
                                       num_steps, target, sequence, embedding)
        total_de += MultiheadAttention(2, num_head,
                                       num_steps, target, sequence, embedding)
        # feed-forward block (linear1 + linear2)
        total_de += num_steps * target * forward * embedding
        total_de += num_steps * target * embedding * forward
        # layernorm
        total_de += 2 * num_steps * embedding * target
        return total_de

    total_ops = encoder_layers * TransformerEncoderLayer(num_head, num_steps, target, sequence, embedding) + \
        decoder_layers * \
        TransformerDecoderLayer(num_head, num_steps,
                                target, sequence, embedding)
    m.total_ops += torch.DoubleTensor([int(total_ops)])

Review comment on def count_Transformer: same issue here.
Reply: fixed, also changed its subfunction; sorry for forgetting to change it after learning CamelCase.
Review comment: Class name should be CamelCased.
Reply: fixed
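For readers checking the arithmetic, the encoder self-attention cost charged by the bool1 == 0 branch above, per encoder layer and before the final multiplication by num_steps, transcribes to

$$ 3SE^{2} \;+\; h\left[ S^{4}\left(\tfrac{E}{h}\right)^{2} + S^{2} + S(3S - 1) + 1 \right] \;+\; SE^{2} \;+\; 2SE $$

where S is the source length, E the embedding dimension (d_model), and h the number of heads. The four groups correspond to the Q/K/V projections, the scaled dot-product attention with softmax, the output projection, and the layer norm, matching the inline comments; the decoder branches follow the same pattern with the target length in place of S.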