Skip to content

Commit a4daf10

Browse files
committed
add (log)softmax benchmarks
1 parent acf87f5 commit a4daf10

File tree

1 file changed

+37
-5
lines changed

1 file changed

+37
-5
lines changed

benchmark/benchmarks.jl

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,45 @@
11
using BenchmarkTools
22
using NNlib
3+
using NNlib.ChainRulesCore: rrule
4+
using Random
5+
6+
Random.seed!(1234567890)
37

48
const SUITE = BenchmarkGroup()
59

610
SUITE["activations"] = BenchmarkGroup()
11+
let x = rand(64, 64)
12+
for f in NNlib.ACTIVATIONS
13+
act = @eval($f)
14+
SUITE["activations"][string(f)] = @benchmarkable $act.($x)
15+
end
16+
end
717

8-
x = rand(64, 64)
9-
10-
for f in NNlib.ACTIVATIONS
11-
act = @eval($f)
12-
SUITE["activations"][string(f)] = @benchmarkable $act.($x)
18+
for (fn!, fn_bw) in [(softmax!, NNlib.∇softmax_data), (logsoftmax!, NNlib.∇logsoftmax_data)]
19+
fn_suite = BenchmarkGroup()
20+
SUITE[rstrip(string(fn!), '!')] = fn_suite
21+
let SIZES = [
22+
(128, 384, 8),
23+
(512, 784, 8),
24+
(768, 1024, 4),
25+
(1024, 2048, 4),
26+
(2048, 2048, 2),
27+
(4096, 2048, 2),
28+
(4096, 4096, 2),
29+
(12288, 2048, 1)
30+
]
31+
for et in (Float16, Float32)
32+
et_suite = BenchmarkGroup("fw" => BenchmarkGroup(), "bw" => BenchmarkGroup())
33+
fn_suite[string(et)] = et_suite
34+
for sz in SIZES
35+
x = randn(et, sz)
36+
y = similar(x)
37+
dy = zero(x)
38+
fn!(y, x)
39+
et_suite["fw"][string(sz)] = @benchmarkable $fn!($y, $x)
40+
et_suite["bw"][string(sz)] = @benchmarkable $fn_bw($dy, $y)
41+
end
42+
end
43+
end
1344
end
45+

0 commit comments

Comments
 (0)