|
1 | 1 | using BenchmarkTools
|
2 | 2 | using NNlib
|
| 3 | +using NNlib.ChainRulesCore: rrule |
| 4 | +using Random |
| 5 | + |
| 6 | +Random.seed!(1234567890) |
3 | 7 |
|
4 | 8 | const SUITE = BenchmarkGroup()
|
5 | 9 |
|
6 | 10 | SUITE["activations"] = BenchmarkGroup()
|
| 11 | +for et in (Float16, Float32, Float64) |
| 12 | + et_suite = BenchmarkGroup() |
| 13 | + SUITE["activations"][string(et)] = et_suite |
| 14 | + let x = rand(et, 1024, 1024), y = similar(x) |
| 15 | + for f in NNlib.ACTIVATIONS |
| 16 | + act = @eval($f) |
| 17 | + et_suite[string(f)] = @benchmarkable broadcast!($act, $y, $x) |
| 18 | + end |
| 19 | + end |
| 20 | +end |
7 | 21 |
|
8 |
| -x = rand(64, 64) |
9 |
| - |
10 |
| -for f in NNlib.ACTIVATIONS |
11 |
| - act = @eval($f) |
12 |
| - SUITE["activations"][string(f)] = @benchmarkable $act.($x) |
| 22 | +for (fn!, fn_bw) in [(softmax!, NNlib.∇softmax_data), (logsoftmax!, NNlib.∇logsoftmax_data)] |
| 23 | + fn_suite = BenchmarkGroup() |
| 24 | + SUITE[rstrip(string(fn!), '!')] = fn_suite |
| 25 | + let SIZES = [ |
| 26 | + (128, 384, 8), |
| 27 | + (512, 784, 8), |
| 28 | + (768, 1024, 4), |
| 29 | + (1024, 2048, 4), |
| 30 | + (2048, 2048, 2), |
| 31 | + (4096, 2048, 2), |
| 32 | + (4096, 4096, 2), |
| 33 | + (12288, 2048, 1) |
| 34 | + ] |
| 35 | + for et in (Float16, Float32) |
| 36 | + et_suite = BenchmarkGroup("fw" => BenchmarkGroup(), "bw" => BenchmarkGroup()) |
| 37 | + fn_suite[string(et)] = et_suite |
| 38 | + for sz in SIZES |
| 39 | + x = randn(et, sz) |
| 40 | + y = similar(x) |
| 41 | + dy = zero(x) |
| 42 | + fn!(y, x) |
| 43 | + et_suite["fw"][string(sz)] = @benchmarkable $fn!($y, $x) |
| 44 | + et_suite["bw"][string(sz)] = @benchmarkable $fn_bw($dy, $y) |
| 45 | + end |
| 46 | + end |
| 47 | + end |
13 | 48 | end
|
| 49 | + |
0 commit comments