Skip to content

Commit dc6f286

Browse files
authored
Merge pull request #1871 from mkschleg/rnn_benchmarks
Recurrent benchmarks
2 parents a851436 + 9d1eb8c commit dc6f286

File tree

3 files changed

+70
-4
lines changed

3 files changed

+70
-4
lines changed

perf/bench_utils.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
using BenchmarkTools
22
using Flux
33
using CUDA
4-
using Zygote: pullback
4+
using Zygote: pullback, ignore
55

66

77
# Forward pass: evaluate model `m` on input `x`.
fw(m, x) = m(x)

# Backward pass: apply a pullback `back` with a unit Float32 seed.
bw(back) = back(1f0)

# Forward + backward: gradient of `sum(m(x))` w.r.t. the params `ps`.
# Routed through `fw` so RNN-specific `fw` specializations are picked up.
fwbw(m, ps, x) = gradient(() -> sum(fw(m, x)), ps)

# Forward pass via `pullback`, returning `(value, back)`; routed through
# `fw` for the same reason as `fwbw`.
pb(m, ps, x) = pullback(() -> sum(fw(m, x)), ps)
11+
1112
function run_benchmark(model, x; cuda=true)
1213

1314
if cuda
@@ -16,7 +17,7 @@ function run_benchmark(model, x; cuda=true)
1617
end
1718

1819
ps = Flux.params(model)
19-
y, back = pullback(() -> sum(model(x)), ps)
20+
y, back = pb(model, ps, x)
2021

2122

2223
if cuda

perf/recurrent.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
2+
3+
# Thin wrapper around a recurrent layer so the benchmark entry points
# (`fw`, `fwbw`, `pb`) can dispatch on it and reset hidden state first.
struct RNNWrapper{T}
    rnn::T
end
Flux.@functor RNNWrapper
7+
8+
# Specializations of `fw` for RNNWrapper: the hidden state must be reset
# before every timed forward pass so runs are comparable.

# Sequence given as a vector of per-timestep arrays: feed one step at a time.
function fw(r::RNNWrapper, X::Vector{<:AbstractArray})
    Flux.reset!(r.rnn)
    return [r.rnn(step) for step in X]
end

# Single dense input: one call through the wrapped layer.
function fw(r::RNNWrapper, X)
    Flux.reset!(r.rnn)
    return r.rnn(X)
end
18+
19+
# Forward + backward for a wrapped RNN over a per-timestep sequence:
# gradient of the summed outputs w.r.t. `ps`.
function fwbw(r::RNNWrapper, ps, X::Vector{<:AbstractArray})
    return gradient(ps) do
        outputs = fw(r, X)
        return sum(sum(outputs))
    end
end

# Same loss as `fwbw`, but via `pullback` so the backward pass can be
# timed separately from the forward pass.
function pb(r::RNNWrapper, ps, X::Vector{<:AbstractArray})
    return pullback(ps) do
        outputs = fw(r, X)
        return sum(sum(outputs))
    end
end
28+
29+
"""
    rnn_benchmark_sweep(data_creator::Function, rnn_type)

Benchmark `rnn_type` on CPU and CUDA across a sweep of layer sizes `n`
and sequence lengths `ts`. `data_creator(n, ts)` must return the input
`x` and a short label describing the input layout (e.g. "Vec", "Block").

CUDA runs that exhaust GPU memory are reported with a warning and
skipped; any other exception is rethrown.
"""
function rnn_benchmark_sweep(data_creator::Function, rnn_type)
    for n in [2, 20, 200, 1000], ts in [1, 4, 16, 64]
        x, x_n = data_creator(n, ts)
        model = RNNWrapper(rnn_type(n, n))

        println("$rnn_type $x_n CPU n=$n, ts=$ts")
        run_benchmark(model, x, cuda=false)

        println("$rnn_type $x_n CUDA n=$n, ts=$ts")
        try
            run_benchmark(model, x, cuda=true)
        catch ex
            @show typeof(ex)
            if ex isa OutOfGPUMemoryError
                @warn "Not enough GPU memory to run test"
            else
                # rethrow() (no argument) preserves the original backtrace,
                # unlike rethrow(ex) which re-raises from here.
                rethrow()
            end
        end
    end
end
50+
51+
# Sweep each recurrent layer type with input as a vector of per-timestep
# matrices ("Vec" layout).
for rnn_type in [Flux.RNN, Flux.GRU, Flux.LSTM]
    rnn_benchmark_sweep(rnn_type) do n, ts
        [randn(Float32, n, n) for _ in 1:ts], "Vec"
    end
end

# Sweep again with a single dense 3-D array ("Block" layout — presumably
# features x batch x time; confirm against Flux's recurrent layer docs).
for rnn_type in [Flux.RNN, Flux.GRU, Flux.LSTM]
    rnn_benchmark_sweep(rnn_type) do n, ts
        randn(Float32, n, n, ts), "Block"
    end
end
62+

perf/runbenchmarks.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@ include("conv.jl")
1111

1212
# Reconstructed final version of the script tail: run each benchmark suite.
@info "Benchmark VGG"
include("vgg.jl")

@info "Benchmark Recurrent"
include("recurrent.jl")

0 commit comments

Comments
 (0)