Skip to content

Commit 9a8a676

Browse files
authored
Merge pull request #2073 from mcognetta/RNN_relax_types
Relax `RNN`/`LSTM`/`GRUCell` internal matrix type restrictions
2 parents 1797f2a + 93071e2 commit 9a8a676

File tree

2 files changed

+60
-17
lines changed

2 files changed

+60
-17
lines changed

src/layers/recurrent.jl

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -189,18 +189,18 @@ end
189189

190190
# Vanilla RNN
191191

192-
struct RNNCell{F,A,V,S}
192+
struct RNNCell{F,I,H,V,S}
193193
σ::F
194-
Wi::A
195-
Wh::A
194+
Wi::I
195+
Wh::H
196196
b::V
197197
state0::S
198198
end
199199

200200
RNNCell((in, out)::Pair, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
201201
RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))
202202

203-
function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T}
203+
function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,I,H,V,T}
204204
Wi, Wh, b = m.Wi, m.Wh, m.b
205205
σ = NNlib.fast_act(m.σ, x)
206206
h = σ.(Wi*x .+ Wh*h .+ b)
@@ -271,15 +271,27 @@ julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
271271
julia> r(rand(Float32, 3)) |> size # erroneously outputs a length 5*10 = 50 vector.
272272
(50,)
273273
```
274+
275+
# Note:
276+
`RNNCell`s can be constructed directly by specifying the non-linear function, the `W_i` and `W_h` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The `W_i` and `W_h` matrices do not need to be of the same type; however, if `W_h` is `d×d`, then `W_i` must have shape `d×N`, where `N` is the input dimension.
277+
278+
```julia
279+
julia> using LinearAlgebra
280+
281+
julia> r = Flux.Recur(Flux.RNNCell(tanh, rand(5, 4), Tridiagonal(rand(5, 5)), rand(5), rand(5, 1)))
282+
283+
julia> r(rand(4, 10)) |> size # batch size of 10
284+
(5, 10)
285+
```
274286
"""
275287
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
276288
Recur(m::RNNCell) = Recur(m, m.state0)
277289

278290
# LSTM
279291

280-
struct LSTMCell{A,V,S}
281-
Wi::A
282-
Wh::A
292+
struct LSTMCell{I,H,V,S}
293+
Wi::I
294+
Wh::H
283295
b::V
284296
state0::S
285297
end
@@ -293,7 +305,7 @@ function LSTMCell((in, out)::Pair;
293305
return cell
294306
end
295307

296-
function (m::LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
308+
function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,T}
297309
b, o = m.b, size(h, 1)
298310
g = muladd(m.Wi, x, muladd(m.Wh, h, b))
299311
input, forget, cell, output = multigate(g, o, Val(4))
@@ -351,17 +363,17 @@ function _gru_output(gxs, ghs, bs)
351363
return r, z
352364
end
353365

354-
struct GRUCell{A,V,S}
355-
Wi::A
356-
Wh::A
366+
struct GRUCell{I,H,V,S}
367+
Wi::I
368+
Wh::H
357369
b::V
358370
state0::S
359371
end
360372

361373
GRUCell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
362374
GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))
363375

364-
function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
376+
function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,T}
365377
Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1)
366378
gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3))
367379
r, z = _gru_output(gxs, ghs, bs)
@@ -414,19 +426,19 @@ Recur(m::GRUCell) = Recur(m, m.state0)
414426

415427
# GRU v3
416428

417-
struct GRUv3Cell{A,V,S}
418-
Wi::A
419-
Wh::A
429+
struct GRUv3Cell{I,H,V,HH,S}
430+
Wi::I
431+
Wh::H
420432
b::V
421-
Wh_h̃::A
433+
Wh_h̃::HH
422434
state0::S
423435
end
424436

425437
GRUv3Cell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
426438
GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
427439
init(out, out), init_state(out,1))
428440

429-
function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
441+
function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,HH,T}
430442
Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1)
431443
gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3))
432444
r, z = _gru_output(gxs, ghs, bs)

test/layers/recurrent.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using LinearAlgebra
2+
13
# Ref FluxML/Flux.jl#1209 1D input
24
@testset "BPTT-1D" begin
35
seq = [rand(Float32, 2) for i = 1:3]
@@ -138,3 +140,32 @@ end
138140
x3,
139141
zeros(x_size[1:end-1]); dims=ndims(x))
140142
end
143+
144+
@testset "Different Internal Matrix Types" begin
145+
R = Flux.Recur(Flux.RNNCell(tanh, rand(5, 3), Tridiagonal(rand(5, 5)), rand(5), rand(5, 1)))
146+
# don't want to pull in SparseArrays just for this test, but there aren't any
147+
# non-square structured matrix types in LinearAlgebra. so we will use a different
148+
# eltype matrix, which would fail before when `W_i` and `W_h` were required to be the
149+
# same type.
150+
L = Flux.Recur(Flux.LSTMCell(rand(5*4, 3), rand(1:20, 5*4, 5), rand(5*4), (rand(5, 1), rand(5, 1))))
151+
G = Flux.Recur(Flux.GRUCell(rand(5*3, 3), rand(1:20, 5*3, 5), rand(5*3), rand(5, 1)))
152+
G3 = Flux.Recur(Flux.GRUv3Cell(rand(5*3, 3), rand(1:20, 5*2, 5), rand(5*3), Tridiagonal(rand(5, 5)), rand(5, 1)))
153+
154+
for m in [R, L, G, G3]
155+
156+
x1 = rand(3)
157+
x2 = rand(3, 1)
158+
x3 = rand(3, 1, 2)
159+
Flux.reset!(m)
160+
@test size(m(x1)) == (5,)
161+
Flux.reset!(m)
162+
@test size(m(x1)) == (5,) # repeat in case of effect from change in state shape
163+
@test size(m(x2)) == (5, 1)
164+
Flux.reset!(m)
165+
@test size(m(x2)) == (5, 1)
166+
Flux.reset!(m)
167+
@test size(m(x3)) == (5, 1, 2)
168+
Flux.reset!(m)
169+
@test size(m(x3)) == (5, 1, 2)
170+
end
171+
end

0 commit comments

Comments
 (0)