diff --git a/src/dnn/rnn.jl b/src/dnn/rnn.jl index f9b7ecf6..d9bf45db 100644 --- a/src/dnn/rnn.jl +++ b/src/dnn/rnn.jl @@ -186,7 +186,7 @@ function pullback(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}) where T <: Unio reserve, (y, ho) = CUDNN.forwardTrain(rnn, x, h) return (y, ho), function (dy, dho) h_ = CUDNN.hBatch(x, h) - dx, dh = CUDNN.backwardData(rnn, y, dy, dho, h_, reserve) + dx, dh = CUDNN.backwardData(rnn, y, dy, isnothing(dho) ? dho : CUDNN.hBatch(y,dho), h_, reserve) (dWi, dWh), db = CUDNN.backwardWeights(rnn, x, h_, y, reserve) return (x = dx, h = dh, Wi = dWi, Wh = dWh, b = db) end @@ -197,7 +197,7 @@ function pullback(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c::CuArray{T}) return (y, ho, co), function (dy, dho, dco) h_ = CUDNN.hBatch(x, h) c_ = CUDNN.hBatch(x, c) - dx, dh, dc = CUDNN.backwardData(rnn, y, dy, dho, dco, h_, c_, reserve) + dx, dh, dc = CUDNN.backwardData(rnn, y, dy, isnothing(dho) ? dho : CUDNN.hBatch(y,dho), dco, h_, c_, reserve) (dWi, dWh), db = CUDNN.backwardWeights(rnn, x, h_, y, reserve) return (x = dx, h = dh, c = dc, Wi = dWi, Wh = dWh, b = db) end