@@ -315,11 +315,11 @@ def test_conv1dtranspose():
 def test_batchnormalization():
     """
     batch normalization is calculated as follows,
-    1. (2 ops * |var|) inv = rsqrt(var + eps)
+    1. (3 ops * |var|) inv = rsqrt(var + eps)
     2. (1 ops * |var|) inv *= gamma (scale)
-    3. (|x| + |mean| + |var| ops) x' = inv * x + beta (shift) - mean * inv
+    3. (2 * |x| + |mean| + |var| ops) x' = inv * x + beta (shift) - mean * inv
     , where |var| = |mean| = channel size by default
-    Thus, tot FLOPs = 5 * channel size + input element size.
+    Thus, tot FLOPs = 6 * channel size + 2 * input element size.
     """
     in_w = 32
     in_h = 32
@@ -334,7 +334,7 @@ def test_batchnormalization():
     )
     flops = get_flops(model, batch_size=1)
     assert (
-        flops == 5 * in_ch + in_w * in_ch
+        flops == 6 * in_ch + 2 * in_w * in_ch
     ), "fused is False. see nn_impl.batch_normalization"

     model = Sequential(
@@ -346,7 +346,7 @@ def test_batchnormalization():
     )
     flops = get_flops(model, batch_size=1)
     assert (
-        flops == 5 * in_ch + in_w * in_h * in_ch
+        flops == 6 * in_ch + 2 * in_w * in_h * in_ch
     ), "fused is True, see gen_nn.fused_batch_norm_v3"


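As a quick cross-check of the updated batch-normalization totals, the per-channel and per-element counts from the docstring above can be tallied by hand. The sketch below is illustrative only; in_ch is a hypothetical channel count, since the model construction is not visible in these hunks.

# Hand tally of the BatchNormalization FLOPs formula from the docstring above.
# in_ch is hypothetical; the test's real channel count is defined outside these hunks.
in_w, in_h, in_ch = 32, 32, 4

per_channel = 3 + 1 + 2  # step 1 (rsqrt(var + eps)) + step 2 (scale) + step 3 (|mean| and |var| terms)
per_element = 2          # step 3: one multiply and one add per input element

unfused_1d = per_channel * in_ch + per_element * in_w * in_ch       # 1D input, fused=False
fused_2d = per_channel * in_ch + per_element * in_w * in_h * in_ch  # 2D input, fused=True

assert unfused_1d == 6 * in_ch + 2 * in_w * in_ch
assert fused_2d == 6 * in_ch + 2 * in_w * in_h * in_ch
print(unfused_1d, fused_2d)  # 280 8216 with the hypothetical in_ch = 4
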
@@ -355,7 +355,7 @@ def test_additive_attention():
     Bahdanau-style attention. query (batch, Tq, dim), key (batch, Tv, dim) and value (batch, Tv, dim) are inputs.
     following computations are processed.
     1. reshape query as shape [batch, Tq, 1, dim] and value as shape [batch, 1, Tv, dim]
-    2. broadcasting multiply between both of above as output shape [batch, Tq, Tv, dim]
+    2. broadcasting add between both of above, then multiply by scale, as output shape [batch, Tq, Tv, dim]
     3. reduce_sum above with dim axis as output shape [batch, Tq, Tv]
     4. softmax of above
     5. MatMul between 4. and value as output shape [batch, Tq, dim]
@@ -375,6 +375,7 @@ def test_additive_attention():
     assert (
         flops
         == Tq * Tv * dim  # No.2 (multiply)
+        + Tq * Tv * dim  # No.2 (add)
         + Tq * Tv * (dim - 1)  # No.3 (reduce_sum)
         + 5 * Tq * Tv  # No.4 (softmax)
         + 2 * Tv * Tq * dim  # No.5 (MatMul)
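
The new Tq * Tv * dim term accounts for the broadcast add inside the score computation (AdditiveAttention reduces scale * tanh(query + key) over the feature axis; tanh does not appear in the asserted count). As a rough sanity check, the expected total can be tallied with hypothetical Tq, Tv and dim values, since the test's actual sizes are defined outside this hunk.

# Hand tally of the AdditiveAttention FLOPs expected by the updated assertion.
# Tq, Tv, dim are hypothetical; the test's real values are defined outside this hunk.
Tq, Tv, dim = 10, 12, 16

expected = (
    Tq * Tv * dim          # No.2: scale multiply over [Tq, Tv, dim]
    + Tq * Tv * dim        # No.2: broadcast add of query and key
    + Tq * Tv * (dim - 1)  # No.3: reduce_sum over the dim axis
    + 5 * Tq * Tv          # No.4: softmax of the [Tq, Tv] scores
    + 2 * Tv * Tq * dim    # No.5: matmul of the attention weights with value
)
print(expected)  # 10080 with these hypothetical sizes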