Description
In the following minimal RNN example, the first call to Flux.train! fails with CUDNN_STATUS_BAD_PARAM. The same error is raised whether the cell is GRU, RNN or LSTM.
using Flux
using Statistics
rnn = Chain(GRU(16, 8),
Dense(8,1, σ),
x -> reshape(x,:))
X = [rand(16, 10) for i in 1:20]   # sequence of 20 timesteps, each 16 features × batch of 10
Y = rand(10, 20) ./ 10             # targets: batch of 10 × 20 timesteps
rnn = rnn |> gpu
X = gpu(X)
Y = gpu(Y)
θ = Flux.params(rnn)
loss(x, y) = mean((Flux.stack(rnn.(x), 2) .- y) .^ 2f0)
opt = ADAM(1e-3)
size(rnn[1].state)                  # (8,) — initial state
Flux.reset!(rnn)
size(rnn[1].state)                  # still (8,) after reset!
Flux.train!(loss, θ, [(X, Y)], opt) # first call fails with CUDNN_STATUS_BAD_PARAM on GPU
size(rnn[1].state)                  # (8, 10) after train!
loss(X, Y)                          # a plain forward pass also leaves the state at (8, 10)
It can be observed that both before and after reset!, the rnn state has size (8,), while after a call to train! on GPU the state takes the expected size (8, 10). After each call to reset!, the CUDNN_STATUS_BAD_PARAM error appears on the first call to train!, but subsequent calls are fine since the state size stays (8, 10). I can't confirm whether that state size is the root cause, but it appears closely tied to the bug. A call to loss(X, Y) also results in the proper state dimension of (8, 10). Running on CPU doesn't produce any error or warning.
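A possible workaround, sketched below and untested beyond the observations above, is to run a forward pass after each reset! and before train!, so the hidden state already has the batched (8, 10) shape by the time cuDNN is invoked:

Flux.reset!(rnn)
rnn.(X)                             # forward pass expands the hidden state to (8, 10)
size(rnn[1].state)                  # (8, 10)
Flux.train!(loss, θ, [(X, Y)], opt) # presumably avoids the error, since train! works once the state is (8, 10)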
Pkg info (the same error is also raised with the latest Zygote release):
[4f1ea46c] AWSCore v0.6.9
[1c724243] AWSS3 v0.6.10
[fbb218c0] BSON v0.2.5
[336ed68f] CSV v0.6.1
[3a865a2d] CuArrays v2.0.1
[a93c6f00] DataFrames v0.20.2
[587475ba] Flux v0.10.4
[28b8d3ca] GR v0.48.0
[91a5bcdd] Plots v1.0.5
[2913bbd2] StatsBase v0.32.2
[e88e6eb3] Zygote v0.4.15 #master (https://github.com/FluxML/Zygote.jl.git)
[10745b16] Statistics
CUDA: 10.1.168
CUDNN: 7.6.5