
Commit 57ef5c0

Make RNN layers accept in => out (#1886)

* let RNN layers accept in => out
* make all signatures agree

1 parent 89d5137 commit 57ef5c0
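For context, the change gives every recurrent constructor the same `in => out` `Pair` syntax that `Dense` already uses, while keeping the positional form working. A minimal sketch, assuming a Flux version that includes this commit:

```julia
using Flux

# New Pair syntax introduced by this commit:
rnn  = RNN(2 => 5)    # Recur(RNNCell(2 => 5, tanh))
lstm = LSTM(3 => 5)
gru  = GRU(3 => 5)

# The old positional syntax still works, silently forwarded
# through the shims added in src/deprecations.jl below:
rnn_old = RNN(2, 5)
```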

File tree

4 files changed: +44 −36 lines changed


docs/src/models/recurrence.md

Lines changed: 6 additions & 3 deletions

@@ -65,16 +65,19 @@ The `Recur` wrapper stores the state between runs in the `m.state` field.
 If we use the `RNN(2, 5)` constructor – as opposed to `RNNCell` – you'll see that it's simply a wrapped cell.
 
 ```julia
-julia> RNN(2, 5)
-Recur(RNNCell(2, 5, tanh))
+julia> RNN(2, 5)  # or equivalently RNN(2 => 5)
+Recur(
+  RNNCell(2 => 5, tanh),  # 45 parameters
+) # Total: 4 trainable arrays, 45 parameters,
+  # plus 1 non-trainable, 5 parameters, summarysize 412 bytes.
 ```
 
 Equivalent to the `RNN` stateful constructor, `LSTM` and `GRU` are also available.
 
 Using these tools, we can now build the model shown in the above diagram with:
 
 ```julia
-m = Chain(RNN(2, 5), Dense(5 => 1))
+m = Chain(RNN(2 => 5), Dense(5 => 1))
 ```
 In this example, each output has only one component.
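As a usage note, here is a minimal sketch of how the model built above is applied to a sequence, following the conventions of the surrounding recurrence docs (the input sizes are illustrative):

```julia
using Flux

m = Chain(RNN(2 => 5), Dense(5 => 1))

x = [rand(Float32, 2) for t = 1:3]  # a sequence of 3 feature vectors of length 2
y = [m(xt) for xt in x]             # each call advances the hidden state;
                                    # each output has a single component

Flux.reset!(m)                      # reset the stored state before the next sequence
```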

src/deprecations.jl

Lines changed: 5 additions & 0 deletions

@@ -27,3 +27,8 @@ Bilinear(in1::Integer, in2::Integer, out::Integer, σ = identity; kw...) =
   Bilinear((in1, in2) => out, σ; kw...)
 Embedding(in::Integer, out::Integer; kw...) = Embedding(in => out; kw...)
 
+RNNCell(in::Integer, out::Integer, σ = tanh; kw...) = RNNCell(in => out, σ; kw...)
+LSTMCell(in::Integer, out::Integer; kw...) = LSTMCell(in => out; kw...)
+
+GRUCell(in::Integer, out::Integer; kw...) = GRUCell(in => out; kw...)
+GRUv3Cell(in::Integer, out::Integer; kw...) = GRUv3Cell(in => out; kw...)
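These shims are plain method forwarding rather than `@deprecate`, so the old positional form keeps working without a warning. A self-contained sketch of the same pattern, using a hypothetical `MyCell` type that is not part of Flux:

```julia
# Hypothetical cell type, used only to illustrate the forwarding pattern above.
struct MyCell{A}
    W::A
end

# Preferred constructor: destructure the Pair into `in` and `out`.
MyCell((in, out)::Pair; init = randn) = MyCell(init(out, in))

# Backwards-compatible positional method, forwarding to the Pair method.
MyCell(in::Integer, out::Integer; kw...) = MyCell(in => out; kw...)

size(MyCell(2 => 5).W)  # (5, 2)
size(MyCell(2, 5).W)    # (5, 2), same result via the shim
```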

src/layers/recurrent.jl

Lines changed: 26 additions & 26 deletions

@@ -100,7 +100,7 @@ struct RNNCell{F,A,V,S}
   state0::S
 end
 
-RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
+RNNCell((in, out)::Pair, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
   RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))
 
 function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T}
@@ -113,26 +113,26 @@ end
 @functor RNNCell
 
 function Base.show(io::IO, l::RNNCell)
-  print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
+  print(io, "RNNCell(", size(l.Wi, 2), " => ", size(l.Wi, 1))
   l.σ == identity || print(io, ", ", l.σ)
   print(io, ")")
 end
 
 """
-    RNN(in::Integer, out::Integer, σ = tanh)
+    RNN(in => out, σ = tanh)
 
 The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
 output fed back into the input each time step.
 
-The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
+The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
 
 This constructor is syntactic sugar for `Recur(RNNCell(a...))`, and so RNNs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.
 
 # Examples
 ```jldoctest
-julia> r = RNN(3, 5)
+julia> r = RNN(3 => 5)
 Recur(
-  RNNCell(3, 5, tanh),  # 50 parameters
+  RNNCell(3 => 5, tanh),  # 50 parameters
 ) # Total: 4 trainable arrays, 50 parameters,
   # plus 1 non-trainable, 5 parameters, summarysize 432 bytes.
 
@@ -150,9 +150,9 @@ julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
 Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the following example:
 
 ```julia
-julia> r = RNN(3, 5)
+julia> r = RNN(3 => 5)
 Recur(
-  RNNCell(3, 5, tanh),  # 50 parameters
+  RNNCell(3 => 5, tanh),  # 50 parameters
 ) # Total: 4 trainable arrays, 50 parameters,
   # plus 1 non-trainable, 5 parameters, summarysize 432 bytes.
 
@@ -187,7 +187,7 @@ struct LSTMCell{A,V,S}
   state0::S
 end
 
-function LSTMCell(in::Integer, out::Integer;
+function LSTMCell((in, out)::Pair;
                   init = glorot_uniform,
                   initb = zeros32,
                   init_state = zeros32)
@@ -208,15 +208,15 @@ end
 @functor LSTMCell
 
 Base.show(io::IO, l::LSTMCell) =
-  print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
+  print(io, "LSTMCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷4, ")")
 
 """
-    LSTM(in::Integer, out::Integer)
+    LSTM(in => out)
 
 [Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
 recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences.
 
-The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
+The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
 
 This constructor is syntactic sugar for `Recur(LSTMCell(a...))`, and so LSTMs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.
@@ -225,9 +225,9 @@ for a good overview of the internals.
 
 # Examples
 ```jldoctest
-julia> l = LSTM(3, 5)
+julia> l = LSTM(3 => 5)
 Recur(
-  LSTMCell(3, 5),  # 190 parameters
+  LSTMCell(3 => 5),  # 190 parameters
 ) # Total: 5 trainable arrays, 190 parameters,
   # plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB.
 
@@ -261,7 +261,7 @@ struct GRUCell{A,V,S}
   state0::S
 end
 
-GRUCell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
+GRUCell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
   GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))
 
 function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
@@ -276,16 +276,16 @@ end
 @functor GRUCell
 
 Base.show(io::IO, l::GRUCell) =
-  print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
+  print(io, "GRUCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")")
 
 """
-    GRU(in::Integer, out::Integer)
+    GRU(in => out)
 
 [Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. Behaves like an
 RNN but generally exhibits a longer memory span over sequences. This implements
 the variant proposed in v1 of the referenced paper.
 
-The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
+The integer arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
 
 This constructor is syntactic sugar for `Recur(GRUCell(a...))`, and so GRUs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.
@@ -294,9 +294,9 @@ for a good overview of the internals.
 
 # Examples
 ```jldoctest
-julia> g = GRU(3, 5)
+julia> g = GRU(3 => 5)
 Recur(
-  GRUCell(3, 5),  # 140 parameters
+  GRUCell(3 => 5),  # 140 parameters
 ) # Total: 4 trainable arrays, 140 parameters,
   # plus 1 non-trainable, 5 parameters, summarysize 792 bytes.
 
@@ -325,7 +325,7 @@ struct GRUv3Cell{A,V,S}
   state0::S
 end
 
-GRUv3Cell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
+GRUv3Cell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
   GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
             init(out, out), init_state(out,1))
 
@@ -341,16 +341,16 @@ end
 @functor GRUv3Cell
 
 Base.show(io::IO, l::GRUv3Cell) =
-  print(io, "GRUv3Cell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
+  print(io, "GRUv3Cell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")")
 
 """
-    GRUv3(in::Integer, out::Integer)
+    GRUv3(in => out)
 
 [Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. Behaves like an
 RNN but generally exhibits a longer memory span over sequences. This implements
 the variant proposed in v3 of the referenced paper.
 
-The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
+The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
 
 This constructor is syntactic sugar for `Recur(GRUv3Cell(a...))`, and so GRUv3s are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.
@@ -359,9 +359,9 @@ for a good overview of the internals.
 
 # Examples
 ```jldoctest
-julia> g = GRUv3(3, 5)
+julia> g = GRUv3(3 => 5)
 Recur(
-  GRUv3Cell(3, 5),  # 140 parameters
+  GRUv3Cell(3 => 5),  # 140 parameters
 ) # Total: 5 trainable arrays, 140 parameters,
   # plus 1 non-trainable, 5 parameters, summarysize 848 bytes.
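One detail worth noting in the `show` methods above: `LSTMCell` stores its four gate weight matrices stacked into a single `Wi`, and the GRU cells stack three, which is why the output width is recovered as `size(l.Wi, 1) ÷ 4` and `÷ 3` respectively. A quick check, assuming a Flux version with this commit:

```julia
using Flux

c = LSTM(3 => 5).cell  # Recur keeps the wrapped cell in its `cell` field
size(c.Wi)             # (20, 3): four gates of width 5 stacked along dim 1
size(c.Wi, 1) ÷ 4      # 5, i.e. `out`, exactly as Base.show computes it

g = GRU(3 => 5).cell
size(g.Wi, 1) ÷ 3      # 5: the GRU cells stack three gates instead of four
```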

test/layers/recurrent.jl

Lines changed: 7 additions & 7 deletions

@@ -2,7 +2,7 @@
 @testset "BPTT-1D" begin
   seq = [rand(Float32, 2) for i = 1:3]
   for r ∈ [RNN,]
-    rnn = r(2, 3)
+    rnn = r(2 => 3)
     Flux.reset!(rnn)
     grads_seq = gradient(Flux.params(rnn)) do
       sum(rnn.(seq)[3])
@@ -24,7 +24,7 @@ end
 @testset "BPTT-2D" begin
   seq = [rand(Float32, (2, 1)) for i = 1:3]
   for r ∈ [RNN,]
-    rnn = r(2, 3)
+    rnn = r(2 => 3)
     Flux.reset!(rnn)
     grads_seq = gradient(Flux.params(rnn)) do
       sum(rnn.(seq)[3])
@@ -44,7 +44,7 @@ end
 
 @testset "BPTT-3D" begin
   seq = rand(Float32, (2, 1, 3))
-  rnn = RNN(2, 3)
+  rnn = RNN(2 => 3)
   Flux.reset!(rnn)
   grads_seq = gradient(Flux.params(rnn)) do
     sum(rnn(seq)[:, :, 3])
@@ -70,9 +70,9 @@ end
 
 @testset "RNN-shapes" begin
   @testset for R in [RNN, GRU, LSTM, GRUv3]
-    m1 = R(3, 5)
-    m2 = R(3, 5)
-    m3 = R(3, 5)
+    m1 = R(3 => 5)
+    m2 = R(3 => 5)
+    m3 = R(3, 5)  # leave one to test the silently deprecated "," not "=>" notation
     x1 = rand(Float32, 3)
     x2 = rand(Float32, 3, 1)
     x3 = rand(Float32, 3, 1, 2)
@@ -90,7 +90,7 @@ end
 
 @testset "RNN-input-state-eltypes" begin
   @testset for R in [RNN, GRU, LSTM, GRUv3]
-    m = R(3, 5)
+    m = R(3 => 5)
     x = rand(Float64, 3, 1)
     Flux.reset!(m)
     @test_throws MethodError m(x)
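The last testset above pins down a deliberate design choice: the cells dispatch on the eltype of their state, so feeding `Float64` data to a default `Float32` layer raises a `MethodError` instead of silently promoting. A minimal reproduction, assuming a Flux version with this commit:

```julia
using Flux

m = RNN(3 => 5)             # parameters and state are Float32 by default
Flux.reset!(m)

m(rand(Float32, 3, 1))      # works: eltypes match

try
    m(rand(Float64, 3, 1))  # no matching method for Float64 input
catch err
    @show typeof(err)       # MethodError
end
```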
