@@ -3,6 +3,19 @@ gate(h, n) = (1:h) .+ h*(n-1)
3
3
gate (x:: AbstractVector , h, n) = @view x[gate (h,n)]
4
4
gate (x:: AbstractMatrix , h, n) = view (x, gate (h,n), :)
5
5
6
+ multigate (x:: AbstractArray , h, :: Val{N} ) where N = ntuple (n -> gate (x,h,n), N)
7
+
8
+ @adjoint function multigate (x:: AbstractArray , h, c)
9
+ function multigate_pullback (dy)
10
+ dx = Zygote. _zero (x, eltype (x))
11
+ map (multigate (dx, h, c), dy) do dxᵢ, dyᵢ
12
+ dyᵢ != = nothing && (dxᵢ.= Zygote. accum .(dxᵢ, dyᵢ));
13
+ end
14
+ return (dx, nothing , nothing )
15
+ end
16
+ return multigate (x, h, c), multigate_pullback
17
+ end
18
+
6
19
# Stateful recurrence
7
20
8
21
"""
157
170
function (m:: LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}} )((h, c), x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {A,V,T}
158
171
b, o = m. b, size (h, 1 )
159
172
g = m. Wi* x .+ m. Wh* h .+ b
160
- input = σ .(gate (g, o, 1 ))
161
- forget = σ .(gate (g, o, 2 ))
162
- cell = tanh .(gate (g, o, 3 ))
163
- output = σ .(gate (g, o, 4 ))
164
- c = forget .* c .+ input .* cell
165
- h′ = output .* tanh .(c)
173
+ input, forget, cell, output = multigate (g, o, Val (4 ))
174
+ c = @. σ (forget) * c + σ (input) * tanh (cell)
175
+ h′ = @. σ (output) * tanh (c)
166
176
sz = size (x)
167
177
return (h′, c), reshape (h′, :, sz[2 : end ]. .. )
168
178
end
@@ -203,13 +213,10 @@ end
203
213
204
214
# GRU
205
215
206
- function _gru_output (Wi, Wh, b, x, h)
207
- o = size (h, 1 )
208
- gx, gh = Wi* x, Wh* h
209
- r = σ .(gate (gx, o, 1 ) .+ gate (gh, o, 1 ) .+ gate (b, o, 1 ))
210
- z = σ .(gate (gx, o, 2 ) .+ gate (gh, o, 2 ) .+ gate (b, o, 2 ))
211
-
212
- return gx, gh, r, z
216
+ function _gru_output (gxs, ghs, bs)
217
+ r = @. σ (gxs[1 ] + ghs[1 ] + bs[1 ])
218
+ z = @. σ (gxs[2 ] + ghs[2 ] + bs[2 ])
219
+ return r, z
213
220
end
214
221
215
222
struct GRUCell{A,V,S}
@@ -223,10 +230,11 @@ GRUCell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
223
230
GRUCell (init (out * 3 , in), init (out * 3 , out), initb (out * 3 ), init_state (out,1 ))
224
231
225
232
function (m:: GRUCell{A,V,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {A,V,T}
226
- b, o = m. b, size (h, 1 )
227
- gx, gh, r, z = _gru_output (m. Wi, m. Wh, b, x, h)
228
- h̃ = tanh .(gate (gx, o, 3 ) .+ r .* gate (gh, o, 3 ) .+ gate (b, o, 3 ))
229
- h′ = (1 .- z) .* h̃ .+ z .* h
233
+ Wi, Wh, b, o = m. Wi, m. Wh, m. b, size (h, 1 )
234
+ gxs, ghs, bs = multigate (Wi* x, o, Val (3 )), multigate (Wh* h, o, Val (3 )), multigate (b, o, Val (3 ))
235
+ r, z = _gru_output (gxs, ghs, bs)
236
+ h̃ = @. tanh (gxs[3 ] + r * ghs[3 ] + bs[3 ])
237
+ h′ = @. (1 - z) * h̃ + z * h
230
238
sz = size (x)
231
239
return h′, reshape (h′, :, sz[2 : end ]. .. )
232
240
end
@@ -277,10 +285,11 @@ GRUv3Cell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32)
277
285
init (out, out), init_state (out,1 ))
278
286
279
287
function (m:: GRUv3Cell{A,V,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {A,V,T}
280
- b, o = m. b, size (h, 1 )
281
- gx, gh, r, z = _gru_output (m. Wi, m. Wh, b, x, h)
282
- h̃ = tanh .(gate (gx, o, 3 ) .+ (m. Wh_h̃ * (r .* h)) .+ gate (b, o, 3 ))
283
- h′ = (1 .- z) .* h̃ .+ z .* h
288
+ Wi, Wh, b, Wh_h̃, o = m. Wi, m. Wh, m. b, m. Wh_h̃, size (h, 1 )
289
+ gxs, ghs, bs = multigate (Wi* x, o, Val (3 )), multigate (Wh* h, o, Val (2 )), multigate (b, o, Val (3 ))
290
+ r, z = _gru_output (gxs, ghs, bs)
291
+ h̃ = tanh .(gxs[3 ] .+ (Wh_h̃ * (r .* h)) .+ bs[3 ])
292
+ h′ = @. (1 - z) * h̃ + z * h
284
293
sz = size (x)
285
294
return h′, reshape (h′, :, sz[2 : end ]. .. )
286
295
end
0 commit comments