1
1
2
2
gate (h, n) = (1 : h) .+ h* (n- 1 )
3
3
gate (x:: AbstractVector , h, n) = @view x[gate (h,n)]
4
- gate (x:: AbstractMatrix , h, n) = x[gate (h,n),:]
4
+ gate (x:: AbstractMatrix , h, n) = view (x, gate (h,n), :)
5
+
6
+ # AD-friendly helper for dividing monolithic RNN params into equally sized gates
7
+ multigate (x:: AbstractArray , h, :: Val{N} ) where N = ntuple (n -> gate (x,h,n), N)
8
+
9
+ @adjoint function multigate (x:: AbstractArray , h, c)
10
+ function multigate_pullback (dy)
11
+ dx = Zygote. _zero (x, eltype (x))
12
+ map (multigate (dx, h, c), dy) do dxᵢ, dyᵢ
13
+ dyᵢ != = nothing && (dxᵢ.= Zygote. accum .(dxᵢ, dyᵢ));
14
+ end
15
+ return (dx, nothing , nothing )
16
+ end
17
+ return multigate (x, h, c), multigate_pullback
18
+ end
19
+
20
+ reshape_cell_output (h, x) = reshape (h, :, size (x)[2 : end ]. .. )
5
21
6
22
# Stateful recurrence
7
23
@@ -97,14 +113,13 @@ struct RNNCell{F,A,V,S}
97
113
state0:: S
98
114
end
99
115
100
- RNNCell (in:: Integer , out:: Integer , σ= tanh; init= Flux. glorot_uniform, initb= zeros32, init_state= zeros32) =
116
+ RNNCell (in:: Integer , out:: Integer , σ= tanh; init= Flux. glorot_uniform, initb= zeros32, init_state= zeros32) =
101
117
RNNCell (σ, init (out, in), init (out, out), initb (out), init_state (out,1 ))
102
118
103
119
function (m:: RNNCell{F,A,V,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {F,A,V,T}
104
120
σ, Wi, Wh, b = m. σ, m. Wi, m. Wh, m. b
105
121
h = σ .(Wi* x .+ Wh* h .+ b)
106
- sz = size (x)
107
- return h, reshape (h, :, sz[2 : end ]. .. )
122
+ return h, reshape_cell_output (h, x)
108
123
end
109
124
110
125
@functor RNNCell
@@ -208,14 +223,10 @@ end
208
223
function (m:: LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}} )((h, c), x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {A,V,T}
209
224
b, o = m. b, size (h, 1 )
210
225
g = m. Wi* x .+ m. Wh* h .+ b
211
- input = σ .(gate (g, o, 1 ))
212
- forget = σ .(gate (g, o, 2 ))
213
- cell = tanh .(gate (g, o, 3 ))
214
- output = σ .(gate (g, o, 4 ))
215
- c = forget .* c .+ input .* cell
216
- h′ = output .* tanh .(c)
217
- sz = size (x)
218
- return (h′, c), reshape (h′, :, sz[2 : end ]. .. )
226
+ input, forget, cell, output = multigate (g, o, Val (4 ))
227
+ c′ = @. σ (forget) * c + σ (input) * tanh (cell)
228
+ h′ = @. σ (output) * tanh (c′)
229
+ return (h′, c′), reshape_cell_output (h′, x)
219
230
end
220
231
221
232
@functor LSTMCell
@@ -269,7 +280,7 @@ function Base.getproperty(m::LSTMCell, sym::Symbol)
269
280
elseif sym === :c
270
281
Zygote. ignore () do
271
282
@warn " LSTMCell field :c has been deprecated. Use m::LSTMCell.state0[2] instead."
272
- end
283
+ end
273
284
return getfield (m, :state0 )[2 ]
274
285
else
275
286
return getfield (m, sym)
@@ -278,13 +289,10 @@ end
278
289
279
290
# GRU
280
291
281
- function _gru_output (Wi, Wh, b, x, h)
282
- o = size (h, 1 )
283
- gx, gh = Wi* x, Wh* h
284
- r = σ .(gate (gx, o, 1 ) .+ gate (gh, o, 1 ) .+ gate (b, o, 1 ))
285
- z = σ .(gate (gx, o, 2 ) .+ gate (gh, o, 2 ) .+ gate (b, o, 2 ))
286
-
287
- return gx, gh, r, z
292
+ function _gru_output (gxs, ghs, bs)
293
+ r = @. σ (gxs[1 ] + ghs[1 ] + bs[1 ])
294
+ z = @. σ (gxs[2 ] + ghs[2 ] + bs[2 ])
295
+ return r, z
288
296
end
289
297
290
298
struct GRUCell{A,V,S}
@@ -298,12 +306,12 @@ GRUCell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
298
306
GRUCell (init (out * 3 , in), init (out * 3 , out), initb (out * 3 ), init_state (out,1 ))
299
307
300
308
function (m:: GRUCell{A,V,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {A,V,T}
301
- b, o = m. b, size (h, 1 )
302
- gx, gh, r, z = _gru_output (m . Wi, m . Wh, b, x, h )
303
- h̃ = tanh .( gate (gx, o, 3 ) .+ r .* gate (gh, o, 3 ) .+ gate (b, o, 3 ) )
304
- h′ = ( 1 .- z) .* h̃ .+ z .* h
305
- sz = size (x)
306
- return h′, reshape (h′, :, sz[ 2 : end ] . .. )
309
+ Wi, Wh, b, o = m . Wi, m . Wh, m. b, size (h, 1 )
310
+ gxs, ghs, bs = multigate (Wi * x, o, Val ( 3 )), multigate (Wh * h, o, Val ( 3 )), multigate ( b, o, Val ( 3 ) )
311
+ r, z = _gru_output (gxs, ghs, bs )
312
+ h̃ = @. tanh (gxs[ 3 ] + r * ghs[ 3 ] + bs[ 3 ])
313
+ h′ = @. ( 1 - z) * h̃ + z * h
314
+ return h′, reshape_cell_output (h′, x )
307
315
end
308
316
309
317
@functor GRUCell
@@ -372,16 +380,16 @@ struct GRUv3Cell{A,V,S}
372
380
end
373
381
374
382
GRUv3Cell (in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
375
- GRUv3Cell (init (out * 3 , in), init (out * 2 , out), initb (out * 3 ),
383
+ GRUv3Cell (init (out * 3 , in), init (out * 2 , out), initb (out * 3 ),
376
384
init (out, out), init_state (out,1 ))
377
385
378
386
function (m:: GRUv3Cell{A,V,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {A,V,T}
379
- b, o = m. b , size (h, 1 )
380
- gx, gh, r, z = _gru_output (m . Wi, m . Wh, b, x, h )
381
- h̃ = tanh .( gate (gx, o, 3 ) .+ (m . Wh_h̃ * (r .* h)) .+ gate (b, o, 3 ) )
382
- h′ = ( 1 .- z) .* h̃ .+ z .* h
383
- sz = size (x)
384
- return h′, reshape (h′, :, sz[ 2 : end ] . .. )
387
+ Wi, Wh, b, Wh_h̃, o = m. Wi, m . Wh, m . b, m . Wh_h̃ , size (h, 1 )
388
+ gxs, ghs, bs = multigate (Wi * x, o, Val ( 3 )), multigate (Wh * h, o, Val ( 2 )), multigate ( b, o, Val ( 3 ) )
389
+ r, z = _gru_output (gxs, ghs, bs )
390
+ h̃ = tanh .(gxs[ 3 ] .+ (Wh_h̃ * (r .* h)) .+ bs[ 3 ])
391
+ h′ = @. ( 1 - z) * h̃ + z * h
392
+ return h′, reshape_cell_output (h′, x )
385
393
end
386
394
387
395
@functor GRUv3Cell
0 commit comments