Skip to content

Commit 3a32ae6

Browse files
committed
Added benchmarks for 3d rnn api.
1 parent da2455f commit 3a32ae6

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

perf/bench_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function run_benchmark(model, x; cuda=true)
2828
end
2929

3030
ps = Flux.params(model)
31-
y, back = if model isa Flux.Recur
31+
y, back = if model isa Flux.Recur && eltype(x) isa AbstractVector
3232
pullback(() -> sum(sum([model(x_t) for x_t in x])), ps)
3333
else
3434
pullback(() -> sum(model(x)), ps)

perf/recurrent.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,15 @@ for n in [2, 20, 200, 2000], T in [1, 8, 16, 64]
99
run_benchmark(model, x, cuda=true)
1010
end
1111

12+
println("RNN-3d")
13+
for n in [2, 20, 200, 2000], T in [1, 8, 16, 64]
14+
x = randn(Float32, n, n, T)
15+
model = RNN(n, n)
16+
println("CPU n=$n, t=$T")
17+
run_benchmark(model, x, cuda=false)
18+
println("CUDA n=$n, t=$T")
19+
run_benchmark(model, x, cuda=true)
20+
end
21+
1222

1323

0 commit comments

Comments
 (0)