@@ -3,6 +3,7 @@ gate(h, n) = (1:h) .+ h*(n-1)
3
3
# Vector method: view of the entries of `x` selected by the index range `gate(h, n)`,
# i.e. the `n`-th block of `h` consecutive entries. A view is returned, not a copy.
gate(x::AbstractVector, h, n) = @view x[gate(h, n)]
4
4
# Matrix method: view of the rows of `x` selected by the index range `gate(h, n)`,
# keeping every column. A view is returned, not a copy.
gate(x::AbstractMatrix, h, n) = view(x, gate(h, n), :)
5
5
6
# AD-friendly helper for dividing monolithic RNN params into equally sized gates.
# Returns an `N`-tuple whose `n`-th element is `gate(x, h, n)`; `N` is taken from the
# `Val{N}` argument so the tuple length is known from the argument type.
multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x, h, n), N)
7
8
8
9
@adjoint function multigate (x:: AbstractArray , h, c)
@@ -16,6 +17,8 @@ multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x,h,n), N)
16
17
return multigate (x, h, c), multigate_pullback
17
18
end
18
19
20
# Reshape the cell output `h` so that every dimension after the first matches the
# corresponding dimension of the input `x`; the first dimension is inferred (`:`).
# For a vector `x` there are no trailing dimensions, so `h` is flattened to a vector.
reshape_cell_output(h, x) = reshape(h, :, size(x)[2:end]...)
21
+
19
22
# Stateful recurrence
20
23
21
24
"""
@@ -116,8 +119,7 @@ RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zero
116
119
# Forward step of the RNN cell: combines the input `x` and the previous state `h`
# through the cell's weights, bias and activation, and returns `(new_state, output)`.
function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T}
    σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
    h = σ.(Wi*x .+ Wh*h .+ b)
    # The output is the new state reshaped to follow the trailing dimensions of `x`.
    return h, reshape_cell_output(h, x)
end
122
124
123
125
@functor RNNCell
@@ -171,10 +173,9 @@ function (m::LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{Abstr
171
173
b, o = m. b, size (h, 1 )
172
174
g = m. Wi* x .+ m. Wh* h .+ b
173
175
input, forget, cell, output = multigate (g, o, Val (4 ))
174
- c = @. σ (forget) * c + σ (input) * tanh (cell)
175
- h′ = @. σ (output) * tanh (c)
176
- sz = size (x)
177
- return (h′, c), reshape (h′, :, sz[2 : end ]. .. )
176
+ c′ = @. σ (forget) * c + σ (input) * tanh (cell)
177
+ h′ = @. σ (output) * tanh (c′)
178
+ return (h′, c′), reshape_cell_output (h′, x)
178
179
end
179
180
180
181
@functor LSTMCell
@@ -235,8 +236,7 @@ function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},O
235
236
r, z = _gru_output (gxs, ghs, bs)
236
237
h̃ = @. tanh (gxs[3 ] + r * ghs[3 ] + bs[3 ])
237
238
h′ = @. (1 - z) * h̃ + z * h
238
- sz = size (x)
239
- return h′, reshape (h′, :, sz[2 : end ]. .. )
239
+ return h′, reshape_cell_output (h′, x)
240
240
end
241
241
242
242
@functor GRUCell
@@ -290,8 +290,7 @@ function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T}
290
290
r, z = _gru_output (gxs, ghs, bs)
291
291
h̃ = tanh .(gxs[3 ] .+ (Wh_h̃ * (r .* h)) .+ bs[3 ])
292
292
h′ = @. (1 - z) * h̃ + z * h
293
- sz = size (x)
294
- return h′, reshape (h′, :, sz[2 : end ]. .. )
293
+ return h′, reshape_cell_output (h′, x)
295
294
end
296
295
297
296
@functor GRUv3Cell
0 commit comments