
Commit add28e3

Updated recurrent benchmarks based on suggestions. Modified bench_utils
so that its behaviour is easier to overload.
1 parent 1525b30 commit add28e3

File tree: 2 files changed, +53 -47 lines

perf/bench_utils.jl

Lines changed: 3 additions & 18 deletions
@@ -6,20 +6,9 @@ using Zygote: pullback, ignore
 
 fw(m, x) = m(x)
 bw(back) = back(1f0)
-fwbw(m, ps, x) = gradient(() -> sum(m(x)), ps)
+fwbw(m, ps, x) = gradient(() -> sum(fw(m, x)), ps)
+pb(m, ps, x) = pullback(()->sum(fw(m, x)), ps)
 
-# Need to specialize for flux.recur.
-fw(m::Flux.Recur, X::Vector{<:AbstractArray}) = begin
-  ignore() do
-    Flux.reset!(m)
-  end
-  [m(x) for x in X]
-end
-fwbw(m::Flux.Recur, ps, X::Vector{<:AbstractArray}) = gradient(ps) do
-  y = fw(m, X)
-  sum(sum(y))
-end
-
 function run_benchmark(model, x; cuda=true)
 
   if cuda
@@ -28,11 +17,7 @@ function run_benchmark(model, x; cuda=true)
   end
 
   ps = Flux.params(model)
-  y, back = if model isa Flux.Recur && eltype(x) <: AbstractArray
-    pullback(() -> sum(sum([model(x_t) for x_t in x])), ps)
-  else
-    pullback(() -> sum(model(x)), ps)
-  end
+  y, back = pb(model, ps, x)
 
 
   if cuda
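The generic fw / bw / fwbw / pb layer above is what makes the benchmark behaviour easy to overload: a new model type only needs extra methods, not edits to run_benchmark. Below is a minimal sketch of that pattern, assuming bench_utils.jl has been included; the Masked wrapper and its mask field are hypothetical and not part of this commit.

using Flux

# Hypothetical wrapper used only to illustrate the overload hook.
struct Masked{T,M}
  model::T
  mask::M
end
Flux.@functor Masked

# One fw method is enough here: the generic fwbw and pb in bench_utils.jl
# both go through fw(m, x), so the forward and gradient benchmarks pick
# this method up automatically.
fw(m::Masked, x) = m.model(x .* m.mask)

For anything more involved (such as the recurrent models below), fwbw and pb can be specialized in the same way.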

perf/recurrent.jl

Lines changed: 50 additions & 29 deletions
@@ -1,41 +1,62 @@
 
-println("RNN")
-for n in [2, 20, 200, 1000], T in [1, 8, 16, 64]
-  x = [randn(Float32, n, n) for t in 1:T]
-  model = RNN(n, n)
-  println("CPU n=$n, t=$T")
-  run_benchmark(model, x, cuda=false)
-  println("CUDA n=$n, t=$T")
-  try
-    run_benchmark(model, x, cuda=true)
-  catch ex
-    @show typeof(ex)
-    if ex isa OutOfGPUMemoryError
-      @warn "Not enough GPU memory to run test"
-    else
-      rethrow(ex)
-    end
-  end
+
+struct RNNWrapper{T}
+  rnn::T
+end
+Flux.@functor RNNWrapper
+
+# Need to specialize for RNNWrapper.
+fw(r::RNNWrapper, X::Vector{<:AbstractArray}) = begin
+  Flux.reset!(r.rnn)
+  [r.rnn(x) for x in X]
+end
+
+fw(r::RNNWrapper, X) = begin
+  Flux.reset!(r.rnn)
+  r.rnn(X)
+end
+
+fwbw(r::RNNWrapper, ps, X::Vector{<:AbstractArray}) = gradient(ps) do
+  y = fw(r, X)
+  sum(sum(y))
+end
+
+pb(r::RNNWrapper, ps, X::Vector{<:AbstractArray}) = pullback(ps) do
+  y = fw(r, X)
+  sum(sum(y))
 end
 
-println("RNN-3d")
-for n in [2, 20, 200, 1000], T in [1, 8, 16, 64]
-  x = randn(Float32, n, n, T)
-  model = RNN(n, n)
-  println("CPU n=$n, t=$T")
-  run_benchmark(model, x, cuda=false)
-  println("CUDA n=$n, t=$T")
-  try
+function rnn_benchmark_sweep(data_creator::Function, rnn_type)
+  for n in [2, 20, 200, 1000], ts in [1, 4, 16, 64]
+    x, x_n = data_creator(n, ts)
+    model = RNNWrapper(rnn_type(n, n))
+
+    println("$rnn_type $x_n CPU n=$n, ts=$ts")
+    run_benchmark(model, x, cuda=false)
+
+    println("$rnn_type $x_n CUDA n=$n, ts=$ts")
+    try
       run_benchmark(model, x, cuda=true)
-  catch ex
+    catch ex
       @show typeof(ex)
       if ex isa OutOfGPUMemoryError
-      @warn "Not enough GPU memory to run test"
+        @warn "Not enough GPU memory to run test"
      else
-      rethrow(ex)
+        rethrow(ex)
      end
-  end
+    end
+  end
 end
 
+for rnn_type in [Flux.RNN, Flux.GRU, Flux.LSTM]
+  rnn_benchmark_sweep(rnn_type) do n, ts
+    [randn(Float32, n, n) for _ in 1:ts], "Vec"
+  end
+end
 
+for rnn_type in [Flux.RNN, Flux.GRU, Flux.LSTM]
+  rnn_benchmark_sweep(rnn_type) do n, ts
+    randn(Float32, n, n, ts), "Block"
+  end
+end
 
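Because rnn_benchmark_sweep receives the data creator as a do-block returning an (input, name) tuple, adding another input shape to the sweep only takes a few lines. A hedged sketch, assuming the file above has been included; the batch-size-1 "Col" case is an illustration only, not part of the commit.

# Hypothetical extra sweep: one column per time step (batch size 1).
for rnn_type in [Flux.RNN, Flux.GRU, Flux.LSTM]
  rnn_benchmark_sweep(rnn_type) do n, ts
    [randn(Float32, n, 1) for _ in 1:ts], "Col"
  end
end

This reuses the Vector{<:AbstractArray} specializations of fw, fwbw and pb defined for RNNWrapper above.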
