
RNN on GPU fails on first backward call #1114

Closed

JuliaGPU/CuArrays.jl#706

Description

@jeremiedb

In the following minimal RNN example, the first call to Flux.train! fails with CUDNN_STATUS_BAD_PARAM. The same error is raised whether the cell is a GRU, RNN, or LSTM.

using Flux
using Statistics

rnn = Chain(GRU(16, 8),
  Dense(8,1, σ),
  x -> reshape(x,:))

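# X: 20 timesteps, each a 16-feature × 10-observation batch; Y: one target per observation and timestep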
X = [rand(16,10) for i in 1:20]
Y = rand(10,20) ./ 10

rnn = rnn |> gpu
X = gpu(X)
Y = gpu(Y)

θ = Flux.params(rnn)
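# broadcast the model over the timesteps and stack the per-timestep outputs along dim 2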
loss(x,y) = mean((Flux.stack(rnn.(x),2) .- y) .^ 2f0)
opt = ADAM(1e-3)
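# check the hidden state size around reset! and train! (see observations below)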
size(rnn[1].state)
Flux.reset!(rnn)
size(rnn[1].state)
Flux.train!(loss, θ, [(X,Y)], opt)
size(rnn[1].state)
loss(X,Y)

It can be observed that both before and after reset!, the rnn state has size (8,), while after a call to train! on GPU the state takes on the expected size (8, 10). After each call to reset!, the CUDNN_STATUS_BAD_PARAM error pops up on the first call to train!, but subsequent calls are fine as long as the state size stays (8, 10). I can't confirm whether that state size is the root cause, but it appears closely tied to the bug. Also, a plain call to loss(X,Y) results in the proper state dimension of (8, 10). Running on CPU doesn't result in any error/warning.
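If the (8,) initial state is indeed what trips CUDNN, one possible workaround (an untested sketch, based only on the observations above) is to prime the state with a plain forward pass after each reset!, so that the first backward call on GPU already sees the batched (8, 10) state:

Flux.reset!(rnn)
rnn.(X)                              # forward pass only; state should now be (8, 10)
Flux.train!(loss, θ, [(X,Y)], opt)   # first backward call then runs with the batched state

This only prearranges the state shape; it doesn't address why the (8,) case fails inside CUDNN in the first place.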

Pkg info (the same error is also raised with the latest Zygote release):

  [4f1ea46c] AWSCore v0.6.9
  [1c724243] AWSS3 v0.6.10
  [fbb218c0] BSON v0.2.5
  [336ed68f] CSV v0.6.1
  [3a865a2d] CuArrays v2.0.1
  [a93c6f00] DataFrames v0.20.2
  [587475ba] Flux v0.10.4
  [28b8d3ca] GR v0.48.0
  [91a5bcdd] Plots v1.0.5
  [2913bbd2] StatsBase v0.32.2
  [e88e6eb3] Zygote v0.4.15 #master (https://github.com/FluxML/Zygote.jl.git)
  [10745b16] Statistics

CUDA: 10.1.168
CUDNN: 7.6.5
