
Commit 66a84ef

bors[bot] and mcognetta authored
Merge #1772
1772: Expand RNN/LSTM/GRU docs r=ToucheSir a=mcognetta

This PR adds expanded documentation to the RNN/LSTM/GRU/GRUv3 docs, resolving #1696. It addresses the `in` and `out` parameter meanings and adds a warning about a common gotcha (not calling `reset!` when batch sizes are changed).

Co-authored-by: Marco Cognetta <cognetta.marco@gmail.com>
2 parents 2053274 + fbd2ad2 commit 66a84ef

File tree: 1 file changed (+123 −0 lines)


src/layers/recurrent.jl

Lines changed: 123 additions & 0 deletions
@@ -120,6 +120,57 @@ end
The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
output fed back into the input each time step.

The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as an `in x B` matrix, and outputs a vector of length `out` or a batch of vectors of size `out x B`.

This constructor is syntactic sugar for `Recur(RNNCell(a...))`, and so RNNs are stateful. Note that the state shape can change depending on the inputs, so it is good practice to call `reset!` on the model between inference calls if the batch size changes. See the examples below.

# Examples
```jldoctest
julia> r = RNN(3, 5)
Recur(
  RNNCell(3, 5, tanh),                  # 50 parameters
)         # Total: 4 trainable arrays, 50 parameters,
          # plus 1 non-trainable, 5 parameters, summarysize 432 bytes.

julia> r(rand(Float32, 3)) |> size
(5,)

julia> Flux.reset!(r);

julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
(5, 10)
```

!!! warning "Batch size changes"
    Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the following example:

    ```julia
    julia> r = RNN(3, 5)
    Recur(
      RNNCell(3, 5, tanh),              # 50 parameters
    )         # Total: 4 trainable arrays, 50 parameters,
              # plus 1 non-trainable, 5 parameters, summarysize 432 bytes.

    julia> r.state |> size
    (5, 1)

    julia> r(rand(Float32, 3)) |> size
    (5,)

    julia> r.state |> size
    (5, 1)

    julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
    (5, 10)

    julia> r.state |> size # state shape has changed
    (5, 10)

    julia> r(rand(Float32, 3)) |> size # erroneously outputs a length 5*10 = 50 vector
    (50,)
    ```
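Calling `reset!` returns the state to its initial `(5, 1)` shape, after which single-vector inputs once again produce length-5 outputs; a minimal sketch continuing the session above:

```julia
julia> Flux.reset!(r);  # state is back to its (5, 1) initial value

julia> r(rand(Float32, 3)) |> size
(5,)
```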
"""
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
Recur(m::RNNCell) = Recur(m, m.state0)
@@ -178,8 +229,32 @@ Base.show(io::IO, l::LSTMCell) =
[Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences.

The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as an `in x B` matrix, and outputs a vector of length `out` or a batch of vectors of size `out x B`.

This constructor is syntactic sugar for `Recur(LSTMCell(a...))`, and so LSTMs are stateful. Note that the state shape can change depending on the inputs, so it is good practice to call `reset!` on the model between inference calls if the batch size changes. See the examples below.

See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.

# Examples
```jldoctest
julia> l = LSTM(3, 5)
Recur(
  LSTMCell(3, 5),                       # 190 parameters
)         # Total: 5 trainable arrays, 190 parameters,
          # plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB.

julia> l(rand(Float32, 3)) |> size
(5,)

julia> Flux.reset!(l);

julia> l(rand(Float32, 3, 10)) |> size # batch size of 10
(5, 10)
```

!!! warning "Batch size changes"
    Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).
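Because the layer is stateful, a sequence is typically processed one timestep at a time, with the hidden state carrying over between calls. A minimal sketch, where the toy 4-step sequence `xs` is purely illustrative:

```julia
julia> Flux.reset!(l);

julia> xs = [rand(Float32, 3) for _ in 1:4];  # toy 4-step sequence of length-3 feature vectors

julia> ys = [l(x) for x in xs];  # state carries over between timesteps

julia> length(ys), size(ys[1])
(4, (5,))
```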
"""
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
Recur(m::LSTMCell) = Recur(m, m.state0)
@@ -243,8 +318,32 @@ Base.show(io::IO, l::GRUCell) =
RNN but generally exhibits a longer memory span over sequences. This implements
the variant proposed in v1 of the referenced paper.

The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as an `in x B` matrix, and outputs a vector of length `out` or a batch of vectors of size `out x B`.

This constructor is syntactic sugar for `Recur(GRUCell(a...))`, and so GRUs are stateful. Note that the state shape can change depending on the inputs, so it is good practice to call `reset!` on the model between inference calls if the batch size changes. See the examples below.

See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.

# Examples
```jldoctest
julia> g = GRU(3, 5)
Recur(
  GRUCell(3, 5),                        # 140 parameters
)         # Total: 4 trainable arrays, 140 parameters,
          # plus 1 non-trainable, 5 parameters, summarysize 792 bytes.

julia> g(rand(Float32, 3)) |> size
(5,)

julia> Flux.reset!(g);

julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
(5, 10)
```

!!! warning "Batch size changes"
    Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).
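Training follows the same pattern: reset the state, accumulate a loss over the timesteps of a sequence, then differentiate. A minimal sketch, where the sequence `xs` and the sum-of-outputs loss are toy stand-ins:

```julia
julia> Flux.reset!(g);

julia> xs = [rand(Float32, 3) for _ in 1:4];  # toy 4-step sequence

julia> ps = Flux.params(g);

julia> gs = gradient(() -> sum(sum(g(x)) for x in xs), ps);  # loss accumulated over timesteps
```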
"""
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
Recur(m::GRUCell) = Recur(m, m.state0)
@@ -297,8 +396,32 @@ Base.show(io::IO, l::GRUv3Cell) =
RNN but generally exhibits a longer memory span over sequences. This implements
the variant proposed in v3 of the referenced paper.

The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as an `in x B` matrix, and outputs a vector of length `out` or a batch of vectors of size `out x B`.

This constructor is syntactic sugar for `Recur(GRUv3Cell(a...))`, and so GRUv3s are stateful. Note that the state shape can change depending on the inputs, so it is good practice to call `reset!` on the model between inference calls if the batch size changes. See the examples below.

See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.

# Examples
```jldoctest
julia> g = GRUv3(3, 5)
Recur(
  GRUv3Cell(3, 5),                      # 140 parameters
)         # Total: 5 trainable arrays, 140 parameters,
          # plus 1 non-trainable, 5 parameters, summarysize 848 bytes.

julia> g(rand(Float32, 3)) |> size
(5,)

julia> Flux.reset!(g);

julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
(5, 10)
```

!!! warning "Batch size changes"
    Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).
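When successive inference calls use different batch sizes, resetting before each change keeps the state shape consistent with the input; a minimal sketch over a few arbitrary batch sizes:

```julia
julia> for b in (7, 1, 16)   # arbitrary batch sizes
           Flux.reset!(g)    # reset before each change in batch size
           @assert size(g(rand(Float32, 3, b))) == (5, b)
       end
```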
"""
GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
Recur(m::GRUv3Cell) = Recur(m, m.state0)
