@@ -189,18 +189,18 @@ end
189
189
190
190
# Vanilla RNN
191
191
192
- struct RNNCell{F,A ,V,S}
192
+ struct RNNCell{F,I,H ,V,S}
193
193
σ:: F
194
- Wi:: A
195
- Wh:: A
194
+ Wi:: I
195
+ Wh:: H
196
196
b:: V
197
197
state0:: S
198
198
end
199
199
200
200
RNNCell ((in, out):: Pair , σ= tanh; init= Flux. glorot_uniform, initb= zeros32, init_state= zeros32) =
201
201
RNNCell (σ, init (out, in), init (out, out), initb (out), init_state (out,1 ))
202
202
203
- function (m:: RNNCell{F,A, V,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {F,A ,V,T}
203
+ function (m:: RNNCell{F,I,H, V,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {F,I,H ,V,T}
204
204
Wi, Wh, b = m. Wi, m. Wh, m. b
205
205
σ = NNlib. fast_act (m. σ, x)
206
206
h = σ .(Wi* x .+ Wh* h .+ b)
@@ -271,15 +271,27 @@ julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
271
271
julia> r(rand(Float32, 3)) |> size # erroneously outputs a length 5*10 = 50 vector.
272
272
(50,)
273
273
```
274
+
275
+ # Note:
276
+ `RNNCell`s can be constructed directly by specifying the non-linear function `σ`, the internal weight matrices `Wi` and `Wh`, a bias vector `b`, and a learnable initial state `state0`. `Wi` and `Wh` do not need to be of the same type, but their shapes must be compatible: if `Wh` is `d×d`, then `Wi` must be `d×N`, where `N` is the input dimension.
277
+
278
+ ```julia
279
+ julia> using LinearAlgebra
280
+
281
+ julia> r = Flux.Recur(Flux.RNNCell(tanh, rand(5, 4), Tridiagonal(rand(5, 5)), rand(5), rand(5, 1)))
282
+
283
+ julia> r(rand(4, 10)) |> size # batch size of 10
284
+ (5, 10)
285
+ ```
274
286
"""
275
287
RNN (a... ; ka... ) = Recur (RNNCell (a... ; ka... ))
276
288
Recur (m:: RNNCell ) = Recur (m, m. state0)
277
289
278
290
# LSTM
279
291
280
- struct LSTMCell{A ,V,S}
281
- Wi:: A
282
- Wh:: A
292
+ struct LSTMCell{I,H ,V,S}
293
+ Wi:: I
294
+ Wh:: H
283
295
b:: V
284
296
state0:: S
285
297
end
@@ -293,7 +305,7 @@ function LSTMCell((in, out)::Pair;
293
305
return cell
294
306
end
295
307
296
- function (m:: LSTMCell{A, V,<:NTuple{2,AbstractMatrix{T}}} )((h, c), x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {A ,V,T}
308
+ function (m:: LSTMCell{I,H, V,<:NTuple{2,AbstractMatrix{T}}} )((h, c), x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {I,H ,V,T}
297
309
b, o = m. b, size (h, 1 )
298
310
g = muladd (m. Wi, x, muladd (m. Wh, h, b))
299
311
input, forget, cell, output = multigate (g, o, Val (4 ))
@@ -351,17 +363,17 @@ function _gru_output(gxs, ghs, bs)
351
363
return r, z
352
364
end
353
365
354
- struct GRUCell{A ,V,S}
355
- Wi:: A
356
- Wh:: A
366
+ struct GRUCell{I,H ,V,S}
367
+ Wi:: I
368
+ Wh:: H
357
369
b:: V
358
370
state0:: S
359
371
end
360
372
361
373
GRUCell ((in, out):: Pair ; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
362
374
GRUCell (init (out * 3 , in), init (out * 3 , out), initb (out * 3 ), init_state (out,1 ))
363
375
364
- function (m:: GRUCell{A, V,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {A ,V,T}
376
+ function (m:: GRUCell{I,H, V,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {I,H ,V,T}
365
377
Wi, Wh, b, o = m. Wi, m. Wh, m. b, size (h, 1 )
366
378
gxs, ghs, bs = multigate (Wi* x, o, Val (3 )), multigate (Wh* h, o, Val (3 )), multigate (b, o, Val (3 ))
367
379
r, z = _gru_output (gxs, ghs, bs)
@@ -414,19 +426,19 @@ Recur(m::GRUCell) = Recur(m, m.state0)
414
426
415
427
# GRU v3
416
428
417
- struct GRUv3Cell{A,V ,S}
418
- Wi:: A
419
- Wh:: A
429
+ struct GRUv3Cell{I,H,V,HH ,S}
430
+ Wi:: I
431
+ Wh:: H
420
432
b:: V
421
- Wh_h̃:: A
433
+ Wh_h̃:: HH
422
434
state0:: S
423
435
end
424
436
425
437
GRUv3Cell ((in, out):: Pair ; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
426
438
GRUv3Cell (init (out * 3 , in), init (out * 2 , out), initb (out * 3 ),
427
439
init (out, out), init_state (out,1 ))
428
440
429
- function (m:: GRUv3Cell{A,V, <:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {A,V ,T}
441
+ function (m:: GRUv3Cell{I,H,V,HH, <:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {I,H,V,HH ,T}
430
442
Wi, Wh, b, Wh_h̃, o = m. Wi, m. Wh, m. b, m. Wh_h̃, size (h, 1 )
431
443
gxs, ghs, bs = multigate (Wi* x, o, Val (3 )), multigate (Wh* h, o, Val (2 )), multigate (b, o, Val (3 ))
432
444
r, z = _gru_output (gxs, ghs, bs)
0 commit comments