Commit 5c130bf

Add/add v2 op (#11)
* add AddV2 op
1 parent f67795a commit 5c130bf

3 files changed: +21 -11 lines changed


keras_flops/flops_registory.py

Lines changed: 12 additions & 3 deletions
@@ -1,7 +1,10 @@
 import numpy as np
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import graph_util
-from tensorflow.python.profiler.internal.flops_registry import _reduction_op_flops
+from tensorflow.python.profiler.internal.flops_registry import (
+    _reduction_op_flops,
+    _binary_per_element_op_flops,
+)
 
 
 @ops.RegisterStatistics("FusedBatchNormV3", "flops")
@@ -18,8 +21,8 @@ def _flops_fused_batch_norm_v3(graph, node):
         raise ValueError("Only supports inference mode")
 
     num_flops = (
-        in_shape.num_elements()
-        + 4 * variance_shape.num_elements()
+        2 * in_shape.num_elements()
+        + 5 * variance_shape.num_elements()
         + mean_shape.num_elements()
     )
     return ops.OpStats("flops", num_flops)
@@ -31,3 +34,9 @@ def _flops_max(graph, node):
     # reduction - comparison, no finalization
     return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)
 
+
+@ops.RegisterStatistics("AddV2", "flops")
+def _flops_add(graph, node):
+    """inference is supported"""
+    return _binary_per_element_op_flops(graph, node)
+
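For context, _binary_per_element_op_flops is TF's internal helper for element-wise binary ops: it charges a fixed number of FLOPs per element of the statically known output shape. Below is a minimal sketch of that behavior, assuming the TF 2.x flops_registry internals; the sketch's function name is hypothetical.

# Sketch of the per-element counting rule behind the AddV2 registration
# (assumption: mirrors tensorflow/python/profiler/internal/flops_registry.py).
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops

def _binary_per_element_flops_sketch(graph, node, ops_per_element=1):
    # FLOPs = ops_per_element * number of output elements; the output
    # shape must be fully defined for the count to be meaningful.
    out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
    out_shape.assert_is_fully_defined()
    return ops.OpStats("flops", out_shape.num_elements() * ops_per_element)

Registering this for "AddV2" therefore reports one FLOP per output element of every add node in the graph.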
pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 [tool.poetry]
 name = "keras-flops"
-version = "0.1.1"
-description = "FLOPs calculator with tf.profiler for neural network architecture written in tensorflow 2.x (tf.keras)"
+version = "0.1.2"
+description = "FLOPs calculator with tf.profiler for neural network architecture written in tensorflow 2.2+ (tf.keras)"
 authors = ["tokusumi <tksmtoms@gmail.com>"]
 license = "MIT"
 readme = "README.md"
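With AddV2 now registered, element-wise adds become visible to the profiler. A hedged usage sketch of the package's get_flops entry point (the shapes and the expected count are illustrative, not taken from the repository's tests):

# A Keras Add layer lowers to an AddV2 node in TF 2.2+, so it should now
# contribute one FLOP per output element.
import tensorflow as tf
from keras_flops import get_flops

x1 = tf.keras.Input(shape=(16,))
x2 = tf.keras.Input(shape=(16,))
y = tf.keras.layers.Add()([x1, x2])
model = tf.keras.Model(inputs=[x1, x2], outputs=y)

print(get_flops(model, batch_size=1))  # expected: 16, one add per element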

tests/test_flops.py

Lines changed: 7 additions & 6 deletions
@@ -315,11 +315,11 @@ def test_conv1dtranspose():
 def test_batchnormalization():
     """
     batch normalization is calculated as follows,
-    1. (2 ops * |var|) inv = rsqrt(var + eps)
+    1. (3 ops * |var|) inv = rsqrt(var + eps)
     2. (1 ops * |var|) inv *= gamma (scale)
-    3. (|x| + |mean| + |var| ops) x' = inv * x + beta (shift) - mean * inv
+    3. (2 * |x| + |mean| + |var| ops) x' = inv * x + beta (shift) - mean * inv
     , where |var| = |mean| = channel size in default
-    Thus, tot FLOPs = 5 * channel size + input element size.
+    Thus, tot FLOPs = 6 * channel size + 2 * input element size.
     """
     in_w = 32
     in_h = 32
@@ -334,7 +334,7 @@ def test_batchnormalization():
     )
     flops = get_flops(model, batch_size=1)
     assert (
-        flops == 5 * in_ch + in_w * in_ch
+        flops == 6 * in_ch + 2 * in_w * in_ch
     ), "fused is False. see nn_impl.batch_normalization"
 
     model = Sequential(
@@ -346,7 +346,7 @@ def test_batchnormalization():
     )
     flops = get_flops(model, batch_size=1)
     assert (
-        flops == 5 * in_ch + in_w * in_h * in_ch
+        flops == 6 * in_ch + 2 * in_w * in_h * in_ch
     ), "fused is True, see gen_nn.fused_batch_norm_v3"
 
 
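The revised totals in the docstring can be checked by hand. Here is a small sketch that reproduces the accounting above (the shape values are illustrative; the test's actual in_ch is not shown in this diff):

# Accounting from the updated docstring, with |var| = |mean| = ch:
#   1. inv = rsqrt(var + eps)            -> 3 * ch
#   2. inv *= gamma (scale)              -> 1 * ch
#   3. x' = inv * x + beta - mean * inv  -> 2 * x_elems + ch + ch
def batchnorm_flops(x_elems: int, ch: int) -> int:
    return 3 * ch + 1 * ch + (2 * x_elems + ch + ch)  # == 6 * ch + 2 * x_elems

in_w, in_ch = 32, 4  # illustrative values only
assert batchnorm_flops(in_w * in_ch, in_ch) == 6 * in_ch + 2 * in_w * in_ch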
@@ -355,7 +355,7 @@ def test_additive_attention():
     Bahdanau-style attention. query (batch, Tq, dim), key (batch, Tv, dim) and value (batch, Tv, dim) are inputs.
     following computations is processed.
     1. reshape query as shape [batch, Tq, 1, dim] and value as shape [batch, 1, Tv, dim]
-    2. broadcasting multiply between both of above as output shape [batch, Tq, Tv, dim]
+    2. broadcasting add between both of above, then element-wise multiply, as output shape [batch, Tq, Tv, dim]
     3. reduce_sum above with dim axis as output shape [batch, Tq, Tv]
     4. softmax of above
     5. MatMul between 4. and value as output shape [batch, Tq, dim]
@@ -375,6 +375,7 @@ def test_additive_attention():
     assert (
         flops
         == Tq * Tv * dim # No.2 (multiply)
+        + Tq * Tv * dim # No.2 (add)
         + Tq * Tv * (dim - 1) # No.3 (reduce_sum)
         + 5 * Tq * Tv # No.4 (softmax)
         + 2 * Tv * Tq * dim # No.5 (MatMul)
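Putting the pieces together, the asserted total is easy to sanity-check numerically. A short sketch with illustrative sizes (the Tq, Tv, dim values here are placeholders, not the test's):

# Same five terms as the assert above, evaluated for concrete sizes.
Tq, Tv, dim = 4, 5, 8
flops = (
    Tq * Tv * dim          # No.2: broadcast add of reshaped query and value
    + Tq * Tv * dim        # No.2: element-wise multiply
    + Tq * Tv * (dim - 1)  # No.3: reduce_sum over the dim axis
    + 5 * Tq * Tv          # No.4: softmax
    + 2 * Tv * Tq * dim    # No.5: MatMul with value
)
print(flops)  # 160 + 160 + 140 + 100 + 320 = 880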
