Commit 20f9535

update rnn docs
1 parent ea26f45 commit 20f9535

File tree

1 file changed: +277 −0 lines changed

src/layers/recurrent.jl

Lines changed: 277 additions & 0 deletions
@@ -120,6 +120,79 @@ end

The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
output fed back into the input each time step.

The parameters `in` and `out` describe the size of the feature vectors passed as input
and as output. That is, it accepts a vector of length `in` or a batch of vectors
represented as an `in x B` matrix and outputs a vector of length `out` or a batch of
vectors of size `out x B`.

This constructor is syntactic sugar for `Recur(RNNCell(a...))`, and so RNNs are stateful.
Note that the state shape can change depending on the inputs, so it is good practice to
`reset!` the model between inference calls if the batch size changes. See the examples below.

# Examples
```jldoctest
julia> r = RNN(3, 5)
Recur(
  RNNCell(3, 5, tanh),                  # 50 parameters
)                 # Total: 4 trainable arrays, 50 parameters,
                  # plus 1 non-trainable, 5 parameters, summarysize 432 bytes.

julia> r(rand(Float32, 3)) |> size
(5,)

julia> Flux.reset!(r);

julia> r(rand(Float32, 3, 64)) |> size
(5, 64)

julia> Flux.reset!(r);

julia> r(rand(Float32, 3))
5-element Vector{Float32}:
 -0.37216917
 -0.14777198
  0.2281275
  0.32866752
 -0.6388411
```

A demonstration of what happens when `reset!` is not called between batch-size changes:

```jldoctest
julia> r = RNN(3, 5)
Recur(
  RNNCell(3, 5, tanh),                  # 50 parameters
)                 # Total: 4 trainable arrays, 50 parameters,
                  # plus 1 non-trainable, 5 parameters, summarysize 432 bytes.

julia> r.state |> size
(5, 1)

julia> r(rand(Float32, 3))
5-element Vector{Float32}:
  0.3572996
 -0.041238427
  0.19673917
 -0.36114445
 -0.0023919558

julia> r.state |> size
(5, 1)

julia> r(rand(Float32, 3, 10)) # batch size of 10
5×10 Matrix{Float32}:
  0.50832    0.409913   0.392907   0.838393   0.297105   0.432568   0.439304   0.677793   0.690217   0.78335
 -0.36385   -0.271328  -0.405521  -0.443976  -0.279546  -0.171614  -0.328029  -0.551147  -0.272327  -0.336688
  0.272917  -0.0155508  0.0995184  0.580889   0.0502855  0.0375683  0.163693   0.39545    0.294581   0.461731
 -0.353226  -0.924237  -0.816582  -0.694016  -0.530896  -0.783385  -0.584767  -0.854036  -0.832923  -0.730812
  0.418002   0.657771   0.673267   0.388967   0.483295   0.444058   0.490792   0.707697   0.435467   0.350789

julia> r.state |> size # state shape has changed
(5, 10)

julia> r(rand(Float32, 3)) # outputs a length 5*10 = 50 vector
50-element Vector{Float32}:
  0.8532559
 -0.5693587
  0.49786803
  ⋮
 -0.7722325
  0.46099305
```
"""
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
Recur(m::RNNCell) = Recur(m, m.state0)
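Since `Recur` keeps its hidden state between calls, running over a whole sequence is just a matter of applying the layer to each timestep in order, after a `reset!`. A minimal sketch of this pattern (not part of the diff; the sequence length `7` and the sizes are arbitrary):

```julia
using Flux

r = RNN(3, 5)

# A sequence of 7 timesteps, each a length-3 feature vector.
seq = [rand(Float32, 3) for _ in 1:7]

Flux.reset!(r)              # start from the stored initial state
outs = [r(x) for x in seq]  # hidden state carries over from call to call

@assert length(outs) == 7
@assert size(outs[end]) == (5,)
```

Keeping every element of `outs` gives a many-to-many mapping; taking only `outs[end]` gives the usual many-to-one readout.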
@@ -178,8 +251,76 @@ Base.show(io::IO, l::LSTMCell) =

[Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences.

The parameters `in` and `out` describe the size of the feature vectors passed as input
and as output. That is, it accepts a vector of length `in` or a batch of vectors
represented as an `in x B` matrix and outputs a vector of length `out` or a batch of
vectors of size `out x B`.

This constructor is syntactic sugar for `Recur(LSTMCell(a...))`, and so LSTMs are stateful.
Note that the state shape can change depending on the inputs, so it is good practice to
`reset!` the model between inference calls if the batch size changes. See the examples below.

See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.

# Examples
```jldoctest
julia> l = LSTM(3, 5)
Recur(
  LSTMCell(3, 5),                       # 190 parameters
)                 # Total: 5 trainable arrays, 190 parameters,
                  # plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB.

julia> l(rand(Float32, 3)) |> size
(5,)

julia> Flux.reset!(l);

julia> l(rand(Float32, 3, 10)) |> size # batch size of 10
(5, 10)

julia> Flux.reset!(l);

julia> l(rand(Float32, 3))
5-element Vector{Float32}:
 -0.025144277
  0.03836835
  0.13517386
 -0.028824253
 -0.057356793
```

A demonstration of what happens when `reset!` is not called between batch-size changes:

```jldoctest
julia> l = LSTM(3, 5)
Recur(
  LSTMCell(3, 5),                       # 190 parameters
)                 # Total: 5 trainable arrays, 190 parameters,
                  # plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB.

julia> size.(l.state)
((5, 1), (5, 1))

julia> l(rand(Float32, 3))
5-element Vector{Float32}:
 0.038496178
 0.047853474
 0.025309514
 0.0934924
 0.05440048

julia> l(rand(Float32, 3, 10)) # batch size of 10
5×10 Matrix{Float32}:
 0.169775   -0.0268295  0.0985312  0.0335569  0.023051   0.146001   0.0494771  0.12347    0.148342    0.00534695
 0.0784295   0.130255   0.0326518  0.0495609  0.108738   0.10251    0.0519795  0.0673814  0.0804598   0.135432
 0.109187   -0.0267218  0.0772971  0.0200508  0.0108066  0.0921862  0.0346887  0.0831271  0.0978057  -0.00210143
 0.0827624   0.163729   0.10911    0.134769   0.120407   0.0757773  0.0894074  0.130243   0.0895137   0.133424
 0.060574    0.127245   0.0145216  0.0635873  0.108584   0.0954128  0.0529619  0.0665022  0.0689427   0.127494

julia> size.(l.state) # state shape has changed
((5, 10), (5, 10))

julia> l(rand(Float32, 3)) # outputs a length 5*10 = 50 vector
50-element Vector{Float32}:
 0.07209678
 0.1450204
 ⋮
 0.14622498
 0.15595339
```
"""
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
Recur(m::LSTMCell) = Recur(m, m.state0)
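One difference from `RNN` worth noting: the LSTM state is a tuple of two arrays, the hidden state and the cell state, which is why the example above inspects it with `size.(l.state)` rather than `size(l.state)`. A small sketch of that shape bookkeeping, assuming the same `LSTM(3, 5)` layer as in the docstring:

```julia
using Flux

l = LSTM(3, 5)

# The state is a (hidden, cell) tuple; both start out as out×1 arrays.
h, c = l.state
@assert size(h) == (5, 1) && size(c) == (5, 1)

l(rand(Float32, 3, 8))             # a batch of 8 reshapes both state arrays
@assert size.(l.state) == ((5, 8), (5, 8))

Flux.reset!(l)                     # reset! restores the stored initial state
@assert size.(l.state) == ((5, 1), (5, 1))
```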
@@ -243,8 +384,76 @@ Base.show(io::IO, l::GRUCell) =

RNN but generally exhibits a longer memory span over sequences. This implements
the variant proposed in v1 of the referenced paper.

The parameters `in` and `out` describe the size of the feature vectors passed as input
and as output. That is, it accepts a vector of length `in` or a batch of vectors
represented as an `in x B` matrix and outputs a vector of length `out` or a batch of
vectors of size `out x B`.

This constructor is syntactic sugar for `Recur(GRUCell(a...))`, and so GRUs are stateful.
Note that the state shape can change depending on the inputs, so it is good practice to
`reset!` the model between inference calls if the batch size changes. See the examples below.

See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.

# Examples
```jldoctest
julia> g = GRU(3, 5)
Recur(
  GRUCell(3, 5),                        # 140 parameters
)                 # Total: 4 trainable arrays, 140 parameters,
                  # plus 1 non-trainable, 5 parameters, summarysize 792 bytes.

julia> g(rand(Float32, 3)) |> size
(5,)

julia> Flux.reset!(g);

julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
(5, 10)

julia> Flux.reset!(g);

julia> g(rand(Float32, 3))
5-element Vector{Float32}:
  0.05426188
 -0.111508384
  0.04700454
  0.06919164
  0.089212984
```

A demonstration of what happens when `reset!` is not called between batch-size changes:

```jldoctest
julia> g = GRU(3, 5)
Recur(
  GRUCell(3, 5),                        # 140 parameters
)                 # Total: 4 trainable arrays, 140 parameters,
                  # plus 1 non-trainable, 5 parameters, summarysize 792 bytes.

julia> g.state |> size
(5, 1)

julia> g(rand(Float32, 3))
5-element Vector{Float32}:
 -0.11918676
 -0.089210495
  0.027523153
  0.017113047
  0.061968707

julia> g(rand(Float32, 3, 10)) # batch size of 10
5×10 Matrix{Float32}:
 -0.198102   -0.187499   -0.265959   -0.21598    -0.210867   -0.379202  -0.262658  -0.213773  -0.236976  -0.266929
 -0.138773   -0.137587   -0.208564   -0.155394   -0.142374   -0.289558  -0.200516  -0.154471  -0.165038  -0.198165
  0.040142    0.0716526   0.122938    0.0606727   0.00901341  0.0754129  0.107307   0.0551935  0.0366661  0.0648411
  0.0655876   0.0512702  -0.0813906   0.120083    0.0521291   0.175624   0.110025   0.0345626  0.189902  -0.00220774
  0.0756504   0.0913944   0.0982122   0.122272    0.0471702   0.228589   0.168877   0.0778906  0.145469   0.0832033

julia> g.state |> size # state shape has changed
(5, 10)

julia> g(rand(Float32, 3)) # outputs a length 5*10 = 50 vector
50-element Vector{Float32}:
 -0.2639928
 -0.18772684
  ⋮
 -0.022745812
  0.040191136
```
"""
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
Recur(m::GRUCell) = Recur(m, m.state0)
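In a training loop, the `reset!`-between-sequences advice from these docstrings usually looks like the sketch below. The toy data, the `Dense` readout, and the choice of `ADAM` and `mse` are illustrative assumptions, not part of this commit:

```julia
using Flux

g = GRU(3, 5)
readout = Dense(5, 1)
model(seq) = readout(last([g(x) for x in seq]))  # many-to-one: keep the last output

# Hypothetical toy data: 16 sequences of 7 timesteps, each with a scalar target.
data = [([rand(Float32, 3) for _ in 1:7], rand(Float32, 1)) for _ in 1:16]

ps  = Flux.params(g, readout)
opt = ADAM()

for (seq, y) in data
    Flux.reset!(g)                   # fresh state for every sequence
    gs = gradient(ps) do
        Flux.Losses.mse(model(seq), y)
    end
    Flux.update!(opt, ps, gs)
end
```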
@@ -297,8 +506,76 @@ Base.show(io::IO, l::GRUv3Cell) =

RNN but generally exhibits a longer memory span over sequences. This implements
the variant proposed in v3 of the referenced paper.

The parameters `in` and `out` describe the size of the feature vectors passed as input
and as output. That is, it accepts a vector of length `in` or a batch of vectors
represented as an `in x B` matrix and outputs a vector of length `out` or a batch of
vectors of size `out x B`.

This constructor is syntactic sugar for `Recur(GRUv3Cell(a...))`, and so `GRUv3` layers are stateful.
Note that the state shape can change depending on the inputs, so it is good practice to
`reset!` the model between inference calls if the batch size changes. See the examples below.

See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.

# Examples
```jldoctest
julia> g = GRUv3(3, 5)
Recur(
  GRUv3Cell(3, 5),                      # 140 parameters
)                 # Total: 5 trainable arrays, 140 parameters,
                  # plus 1 non-trainable, 5 parameters, summarysize 848 bytes.

julia> g(rand(Float32, 3)) |> size
(5,)

julia> Flux.reset!(g);

julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
(5, 10)

julia> Flux.reset!(g);

julia> g(rand(Float32, 3))
5-element Vector{Float32}:
  0.05637428
  0.0084088165
 -0.036565308
  0.013599886
 -0.0168455
```

A demonstration of what happens when `reset!` is not called between batch-size changes:

```jldoctest
julia> g = GRUv3(3, 5)
Recur(
  GRUv3Cell(3, 5),                      # 140 parameters
)                 # Total: 5 trainable arrays, 140 parameters,
                  # plus 1 non-trainable, 5 parameters, summarysize 848 bytes.

julia> g.state |> size
(5, 1)

julia> g(rand(Float32, 3))
5-element Vector{Float32}:
  0.07569726
  0.23686615
 -0.01647649
  0.100590095
  0.06330994

julia> g(rand(Float32, 3, 10)) # batch size of 10
5×10 Matrix{Float32}:
  0.0187245   0.135969     0.0808607   0.138937    0.0153128   0.0386136  0.0498803  -0.0273552   0.116714    0.0584934
  0.207631    0.146397     0.226232    0.297546    0.28957     0.199815   0.239891    0.27778     0.132326    0.0325415
  0.083468   -0.00669185  -0.0562241   0.00725718  0.0319667  -0.021063   0.0682753   0.0109112   0.0188356   0.0826402
  0.0700071   0.120734     0.108757    0.14339     0.0850359   0.0706199  0.0915005   0.05131     0.105372    0.0507574
  0.0505043  -0.0408188    0.0170817   0.0190653   0.0936475   0.0406348  0.044181    0.139226   -0.0355197  -0.0434937

julia> g.state |> size # state shape has changed
(5, 10)

julia> g(rand(Float32, 3)) # outputs a length 5*10 = 50 vector
50-element Vector{Float32}:
  0.08773954
  0.34562656
  ⋮
  0.13768406
 -0.015648054
```
"""
GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
Recur(m::GRUv3Cell) = Recur(m, m.state0)
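Finally, a quick sanity check of the "syntactic sugar" claim made in all four docstrings: the sugared constructor and the explicit `Recur` wrapper build the same kind of object. A sketch (the cell type is qualified with the `Flux` module in case it is not exported):

```julia
using Flux

g1 = GRUv3(3, 5)
g2 = Flux.Recur(Flux.GRUv3Cell(3, 5))

# Both are stateful Recur wrappers around a GRUv3Cell, starting with a (5, 1) state.
@assert typeof(g1) == typeof(g2)
@assert size(g1.state) == size(g2.state) == (5, 1)
```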
