Skip to content

Commit 2e0bb1d

Browse files
committed
rename intermediate variable and extract reshape helper
1 parent a6c9c11 commit 2e0bb1d

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

src/layers/recurrent.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ gate(h, n) = (1:h) .+ h*(n-1)
33
gate(x::AbstractVector, h, n) = @view x[gate(h,n)]
44
gate(x::AbstractMatrix, h, n) = view(x, gate(h,n), :)
55

6+
# AD-friendly helper for dividing monolithic RNN params into equally sized gates
67
multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x,h,n), N)
78

89
@adjoint function multigate(x::AbstractArray, h, c)
@@ -16,6 +17,8 @@ multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x,h,n), N)
1617
return multigate(x, h, c), multigate_pullback
1718
end
1819

20+
reshape_cell_output(h, x) = reshape(h, :, size(x)[2:end]...)
21+
1922
# Stateful recurrence
2023

2124
"""
@@ -116,8 +119,7 @@ RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zero
116119
function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T}
117120
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
118121
h = σ.(Wi*x .+ Wh*h .+ b)
119-
sz = size(x)
120-
return h, reshape(h, :, sz[2:end]...)
122+
return h, reshape_cell_output(h, x)
121123
end
122124

123125
@functor RNNCell
@@ -171,10 +173,9 @@ function (m::LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{Abstr
171173
b, o = m.b, size(h, 1)
172174
g = m.Wi*x .+ m.Wh*h .+ b
173175
input, forget, cell, output = multigate(g, o, Val(4))
174-
c = @. σ(forget) * c + σ(input) * tanh(cell)
175-
h′ = @. σ(output) * tanh(c)
176-
sz = size(x)
177-
return (h′, c), reshape(h′, :, sz[2:end]...)
176+
c′ = @. σ(forget) * c + σ(input) * tanh(cell)
177+
h′ = @. σ(output) * tanh(c′)
178+
return (h′, c′), reshape_cell_output(h′, x)
178179
end
179180

180181
@functor LSTMCell
@@ -235,8 +236,7 @@ function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},O
235236
r, z = _gru_output(gxs, ghs, bs)
236237
= @. tanh(gxs[3] + r * ghs[3] + bs[3])
237238
h′ = @. (1 - z) *+ z * h
238-
sz = size(x)
239-
return h′, reshape(h′, :, sz[2:end]...)
239+
return h′, reshape_cell_output(h′, x)
240240
end
241241

242242
@functor GRUCell
@@ -290,8 +290,7 @@ function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T}
290290
r, z = _gru_output(gxs, ghs, bs)
291291
= tanh.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3])
292292
h′ = @. (1 - z) *+ z * h
293-
sz = size(x)
294-
return h′, reshape(h′, :, sz[2:end]...)
293+
return h′, reshape_cell_output(h′, x)
295294
end
296295

297296
@functor GRUv3Cell

0 commit comments

Comments
 (0)