|
120 | 120 |
|
121 | 121 | The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
|
122 | 122 | output fed back into the input each time step.
|
| 123 | +
|
| 124 | +The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, the layer accepts a vector of length `in`, or a batch of vectors represented as an `in x B` matrix, and outputs a vector of length `out` or a batch of vectors of size `out x B`. |
| 125 | +
|
| 126 | +This constructor is syntactic sugar for `Recur(RNNCell(a...))`, and so RNNs are stateful. Note that the state shape can change depending on the inputs, so it is recommended to call `Flux.reset!` on the model between inference calls if the batch size changes. See the examples below. |
| 127 | +
|
| 128 | +# Examples |
| 129 | +```jldoctest |
| 130 | +julia> r = RNN(3, 5) |
| 131 | +Recur( |
| 132 | + RNNCell(3, 5, tanh), # 50 parameters |
| 133 | +) # Total: 4 trainable arrays, 50 parameters, |
| 134 | + # plus 1 non-trainable, 5 parameters, summarysize 432 bytes. |
| 135 | +
|
| 136 | +julia> r(rand(Float32, 3)) |> size |
| 137 | +(5,) |
| 138 | +
|
| 139 | +julia> Flux.reset!(r); |
| 140 | +
|
| 141 | +julia> r(rand(Float32, 3, 10)) |> size # batch size of 10 |
| 142 | +(5, 10) |
| 143 | +``` |
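Because the layer is stateful, applying it to successive timesteps of a sequence carries the hidden state from one call to the next. A minimal sketch of that pattern, reusing the `r` defined above (the 4-step random sequence and the comprehension are illustrative assumptions, not part of the constructor's API):

```julia
julia> Flux.reset!(r);

julia> seq = [rand(Float32, 3) for _ in 1:4];   # hypothetical sequence of 4 timesteps

julia> outs = [r(x) for x in seq];              # the hidden state is threaded through the calls

julia> all(==((5,)), size.(outs))               # every timestep yields a length-5 output
true
```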
| 144 | +
|
| 145 | +!!! warning "Batch size changes" |
| 146 | + |
| 147 | + Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the following example: |
| 148 | +
|
| 149 | + ```julia |
| 150 | + julia> r = RNN(3, 5) |
| 151 | + Recur( |
| 152 | + RNNCell(3, 5, tanh), # 50 parameters |
| 153 | + ) # Total: 4 trainable arrays, 50 parameters, |
| 154 | + # plus 1 non-trainable, 5 parameters, summarysize 432 bytes. |
| 155 | +
|
| 156 | + julia> r.state |> size |
| 157 | + (5, 1) |
| 158 | +
|
| 159 | + julia> r(rand(Float32, 3)) |> size |
| 160 | + (5,) |
| 161 | +
|
| 162 | + julia> r.state |> size |
| 163 | + (5, 1) |
| 164 | +
|
| 165 | + julia> r(rand(Float32, 3, 10)) |> size # batch size of 10 |
| 166 | + (5, 10) |
| 167 | +
|
| 168 | + julia> r.state |> size # state shape has changed |
| 169 | + (5, 10) |
| 170 | +
|
| 171 | + julia> r(rand(Float32, 3)) |> size # erroneously outputs a length 5*10 = 50 vector. |
| 172 | + (50,) |
| 173 | + ``` |
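    Calling `Flux.reset!` before switching back to a different batch size avoids this. A brief sketch of the fix, continuing the hypothetical `r` above:

    ```julia
    julia> Flux.reset!(r);   # the state returns to the cell's initial `state0`

    julia> r(rand(Float32, 3)) |> size   # single samples work as expected again
    (5,)
    ```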
123 | 174 | """
|
124 | 175 | RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
|
125 | 176 | Recur(m::RNNCell) = Recur(m, m.state0)
|
@@ -178,8 +229,32 @@ Base.show(io::IO, l::LSTMCell) =
|
178 | 229 | [Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
|
179 | 230 | recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences.
|
180 | 231 |
|
| 232 | +The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, the layer accepts a vector of length `in`, or a batch of vectors represented as an `in x B` matrix, and outputs a vector of length `out` or a batch of vectors of size `out x B`. |
| 233 | +
|
| 234 | +This constructor is syntactic sugar for `Recur(LSTMCell(a...))`, and so LSTMs are stateful. Note that the state shape can change depending on the inputs, so it is recommended to call `Flux.reset!` on the model between inference calls if the batch size changes. See the examples below. |
| 235 | +
|
181 | 236 | See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
182 | 237 | for a good overview of the internals.
|
| 238 | +
|
| 239 | +# Examples |
| 240 | +```jldoctest |
| 241 | +julia> l = LSTM(3, 5) |
| 242 | +Recur( |
| 243 | + LSTMCell(3, 5), # 190 parameters |
| 244 | +) # Total: 5 trainable arrays, 190 parameters, |
| 245 | + # plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB. |
| 246 | +
|
| 247 | +julia> l(rand(Float32, 3)) |> size |
| 248 | +(5,) |
| 249 | +
|
| 250 | +julia> Flux.reset!(l); |
| 251 | +
|
| 252 | +julia> l(rand(Float32, 3, 10)) |> size # batch size of 10 |
| 253 | +(5, 10) |
| 254 | +``` |
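Unlike `RNN`, the recurrent state here is a pair of arrays: the hidden state and the cell state. A small illustrative sketch, reusing the `l` defined above (poking at the `state` field is for exposition only; its exact layout is an implementation detail):

```julia
julia> Flux.reset!(l);

julia> l.state isa Tuple    # hidden state and cell state
true

julia> length(l.state)
2
```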
| 255 | +
|
| 256 | +!!! warning "Batch size changes" |
| 257 | + Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref). |
183 | 258 | """
|
184 | 259 | LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
|
185 | 260 | Recur(m::LSTMCell) = Recur(m, m.state0)
|
@@ -243,8 +318,32 @@ Base.show(io::IO, l::GRUCell) =
|
243 | 318 | RNN but generally exhibits a longer memory span over sequences. This implements
|
244 | 319 | the variant proposed in v1 of the referenced paper.
|
245 | 320 |
|
| 321 | +The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, the layer accepts a vector of length `in`, or a batch of vectors represented as an `in x B` matrix, and outputs a vector of length `out` or a batch of vectors of size `out x B`. |
| 322 | +
|
| 323 | +This constructor is syntactic sugar for `Recur(GRUCell(a...))`, and so GRUs are stateful. Note that the state shape can change depending on the inputs, so it is recommended to call `Flux.reset!` on the model between inference calls if the batch size changes. See the examples below. |
| 324 | +
|
246 | 325 | See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
247 | 326 | for a good overview of the internals.
|
| 327 | +
|
| 328 | +# Examples |
| 329 | +```jldoctest |
| 330 | +julia> g = GRU(3, 5) |
| 331 | +Recur( |
| 332 | + GRUCell(3, 5), # 140 parameters |
| 333 | +) # Total: 4 trainable arrays, 140 parameters, |
| 334 | + # plus 1 non-trainable, 5 parameters, summarysize 792 bytes. |
| 335 | +
|
| 336 | +julia> g(rand(Float32, 3)) |> size |
| 337 | +(5,) |
| 338 | +
|
| 339 | +julia> Flux.reset!(g); |
| 340 | +
|
| 341 | +julia> g(rand(Float32, 3, 10)) |> size # batch size of 10 |
| 342 | +(5, 10) |
| 343 | +``` |
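Recurrent layers compose with other layers in the usual way. For example, a hypothetical sketch of a small model that follows the GRU with a `Dense` projection (the layer sizes are arbitrary choices for illustration):

```julia
julia> m = Chain(GRU(3, 5), Dense(5, 2));

julia> Flux.reset!(m);    # reset! recurses through containers such as Chain

julia> m(rand(Float32, 3)) |> size
(2,)
```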
| 344 | +
|
| 345 | +!!! warning "Batch size changes" |
| 346 | + Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref). |
248 | 347 | """
|
249 | 348 | GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
|
250 | 349 | Recur(m::GRUCell) = Recur(m, m.state0)
|
@@ -297,8 +396,32 @@ Base.show(io::IO, l::GRUv3Cell) =
|
297 | 396 | RNN but generally exhibits a longer memory span over sequences. This implements
|
298 | 397 | the variant proposed in v3 of the referenced paper.
|
299 | 398 |
|
| 399 | +The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, the layer accepts a vector of length `in`, or a batch of vectors represented as an `in x B` matrix, and outputs a vector of length `out` or a batch of vectors of size `out x B`. |
| 400 | +
|
| 401 | +This constructor is syntactic sugar for `Recur(GRUv3Cell(a...))`, and so GRUv3 layers are stateful. Note that the state shape can change depending on the inputs, so it is recommended to call `Flux.reset!` on the model between inference calls if the batch size changes. See the examples below. |
| 402 | +
|
300 | 403 | See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
|
301 | 404 | for a good overview of the internals.
|
| 405 | +
|
| 406 | +# Examples |
| 407 | +```jldoctest |
| 408 | +julia> g = GRUv3(3, 5) |
| 409 | +Recur( |
| 410 | + GRUv3Cell(3, 5), # 140 parameters |
| 411 | +) # Total: 5 trainable arrays, 140 parameters, |
| 412 | + # plus 1 non-trainable, 5 parameters, summarysize 848 bytes. |
| 413 | +
|
| 414 | +julia> g(rand(Float32, 3)) |> size |
| 415 | +(5,) |
| 416 | +
|
| 417 | +julia> Flux.reset!(g); |
| 418 | +
|
| 419 | +julia> g(rand(Float32, 3, 10)) |> size # batch size of 10 |
| 420 | +(5, 10) |
| 421 | +``` |
| 422 | +
|
| 423 | +!!! warning "Batch size changes" |
| 424 | + Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref). |
302 | 425 | """
|
303 | 426 | GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
|
304 | 427 | Recur(m::GRUv3Cell) = Recur(m, m.state0)
|