
Commit da2455f

Adding recurrent perf benchmarks for RNN.

1 parent: 1930966

3 files changed: +34 −2 lines

perf/bench_utils.jl

Lines changed: 18 additions & 2 deletions
@@ -1,12 +1,24 @@
 using BenchmarkTools
 using Flux
 using CUDA
-using Zygote: pullback
+using Zygote: pullback, ignore
 
 
 fw(m, x) = m(x)
 bw(back) = back(1f0)
 fwbw(m, ps, x) = gradient(() -> sum(m(x)), ps)
+
+# Need to specialize for Flux.Recur.
+fw(m::Flux.Recur, X::Vector{<:AbstractArray}) = begin
+  ignore() do
+    Flux.reset!(m)
+  end
+  [m(x) for x in X]
+end
+fwbw(m::Flux.Recur, ps, X::Vector{<:AbstractArray}) = gradient(ps) do
+  y = fw(m, X)
+  sum(sum(y))
+end
 
 function run_benchmark(model, x; cuda=true)
 
@@ -16,7 +28,11 @@ function run_benchmark(model, x; cuda=true)
   end
 
   ps = Flux.params(model)
-  y, back = pullback(() -> sum(model(x)), ps)
+  y, back = if model isa Flux.Recur
+    pullback(() -> sum(sum([model(x_t) for x_t in x])), ps)
+  else
+    pullback(() -> sum(model(x)), ps)
+  end
 
 
   if cuda
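For reference, the specialized fw/fwbw methods above treat a stateful
Flux.Recur model as a sequence processor: the hidden state is reset inside
Zygote's ignore block so the reset is excluded from gradient tracking, and
the model is then applied to each timestep in order. A minimal usage sketch,
with illustrative sizes that are not part of the commit:

    using Flux
    using Zygote: ignore

    m  = RNN(4, 4)                            # a stateful Flux.Recur wrapper
    X  = [randn(Float32, 4, 2) for _ in 1:3]  # 3 timesteps, each 4 features x 2 samples
    ps = Flux.params(m)

    y  = fw(m, X)        # resets the hidden state, then runs all timesteps
    gs = fwbw(m, ps, X)  # gradients of the summed outputs w.r.t. ps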

perf/recurrent.jl

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+
+println("RNN")
+for n in [2, 20, 200, 2000], T in [1, 8, 16, 64]
+  x = [randn(Float32, n, n) for t in 1:T]
+  model = RNN(n, n)
+  println("CPU n=$n, t=$T")
+  run_benchmark(model, x, cuda=false)
+  println("CUDA n=$n, t=$T")
+  run_benchmark(model, x, cuda=true)
+end
+
+
+
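Note the input layout: each configuration builds a length-T vector of n-by-n
Float32 matrices, i.e. T timesteps, each carrying a batch of n samples with n
features. A quick shape check for one configuration (illustrative, not part
of the commit):

    using Flux

    n, T = 20, 8
    x = [randn(Float32, n, n) for t in 1:T]  # T timesteps, each n features x n samples
    model = RNN(n, n)
    Flux.reset!(model)
    y = [model(x_t) for x_t in x]
    @assert size(first(y)) == (n, n)         # every timestep's output is n x n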

perf/runbenchmarks.jl

Lines changed: 3 additions & 0 deletions
@@ -11,3 +11,6 @@ include("conv.jl")
 
 @info "Benchmark VGG"
 include("vgg.jl")
+
+@info "Benchmark Recurrent"
+include("recurrent.jl")
