From 167bbdf99f50e16f3ae25dc2836ada9f44551f34 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 5 May 2020 16:38:57 +0530 Subject: [PATCH 1/2] add hBatch --- src/dnn/rnn.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dnn/rnn.jl b/src/dnn/rnn.jl index f9b7ecf6..6fbb1013 100644 --- a/src/dnn/rnn.jl +++ b/src/dnn/rnn.jl @@ -101,6 +101,7 @@ end hBatch(x::AbstractVector, h::CuVector) = h hBatch(x::AbstractMatrix, h::CuVector) = h .*CuArrays.ones(1, size(x, 2)) hBatch(x::AbstractMatrix, h::CuMatrix) = h .*CuArrays.ones(1, size(h,2) == 1 ? size(x,2) : 1) +hBatch(x, ::Nothing) = nothing function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T h = hBatch(x, h_) @@ -197,7 +198,7 @@ function pullback(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c::CuArray{T}) return (y, ho, co), function (dy, dho, dco) h_ = CUDNN.hBatch(x, h) c_ = CUDNN.hBatch(x, c) - dx, dh, dc = CUDNN.backwardData(rnn, y, dy, dho, dco, h_, c_, reserve) + dx, dh, dc = CUDNN.backwardData(rnn, y, dy, CUDNN.hBatch(y,dho), dco, h_, c_, reserve) (dWi, dWh), db = CUDNN.backwardWeights(rnn, x, h_, y, reserve) return (x = dx, h = dh, c = dc, Wi = dWi, Wh = dWh, b = db) end From a76f65adfae67b0b4d93737c656f7327cf131dd8 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 5 May 2020 16:48:56 +0530 Subject: [PATCH 2/2] check isnothing(dho) --- src/dnn/rnn.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/dnn/rnn.jl b/src/dnn/rnn.jl index 6fbb1013..d9bf45db 100644 --- a/src/dnn/rnn.jl +++ b/src/dnn/rnn.jl @@ -101,7 +101,6 @@ end hBatch(x::AbstractVector, h::CuVector) = h hBatch(x::AbstractMatrix, h::CuVector) = h .*CuArrays.ones(1, size(x, 2)) hBatch(x::AbstractMatrix, h::CuMatrix) = h .*CuArrays.ones(1, size(h,2) == 1 ? size(x,2) : 1) -hBatch(x, ::Nothing) = nothing function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T h = hBatch(x, h_) @@ -187,7 +186,7 @@ function pullback(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}) where T <: Unio reserve, (y, ho) = CUDNN.forwardTrain(rnn, x, h) return (y, ho), function (dy, dho) h_ = CUDNN.hBatch(x, h) - dx, dh = CUDNN.backwardData(rnn, y, dy, dho, h_, reserve) + dx, dh = CUDNN.backwardData(rnn, y, dy, isnothing(dho) ? dho : CUDNN.hBatch(y,dho), h_, reserve) (dWi, dWh), db = CUDNN.backwardWeights(rnn, x, h_, y, reserve) return (x = dx, h = dh, Wi = dWi, Wh = dWh, b = db) end @@ -198,7 +197,7 @@ function pullback(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c::CuArray{T}) return (y, ho, co), function (dy, dho, dco) h_ = CUDNN.hBatch(x, h) c_ = CUDNN.hBatch(x, c) - dx, dh, dc = CUDNN.backwardData(rnn, y, dy, CUDNN.hBatch(y,dho), dco, h_, c_, reserve) + dx, dh, dc = CUDNN.backwardData(rnn, y, dy, isnothing(dho) ? dho : CUDNN.hBatch(y,dho), dco, h_, c_, reserve) (dWi, dWh), db = CUDNN.backwardWeights(rnn, x, h_, y, reserve) return (x = dx, h = dh, c = dc, Wi = dWi, Wh = dWh, b = db) end