Skip to content

Commit fbe3913

Browse files
committed
condensed examples
1 parent 20f9535 commit fbe3913

File tree

1 file changed

+59
-161
lines changed

1 file changed

+59
-161
lines changed

src/layers/recurrent.jl

Lines changed: 59 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -138,18 +138,8 @@ julia> r(rand(Float32, 3)) |> size
138138
139139
julia> Flux.reset!(r);
140140
141-
julia> r(rand(Float32, 3, 64)) |> size
142-
(5, 64)
143-
144-
julia> Flux.reset!(r);
145-
146-
julia> r(rand(Float32, 3))
147-
5-element Vector{Float32}:
148-
-0.37216917
149-
-0.14777198
150-
0.2281275
151-
0.32866752
152-
-0.6388411
141+
julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
142+
(5, 10)
153143
154144
# A demonstration of not using `reset!` when the batch size changes.
155145
julia> r = RNN(3, 5)
@@ -161,38 +151,21 @@ Recur(
161151
julia> r.state |> size
162152
(5, 1)
163153
164-
julia> r(rand(Float32, 3))
165-
5-element Vector{Float32}:
166-
0.3572996
167-
-0.041238427
168-
0.19673917
169-
-0.36114445
170-
-0.0023919558
154+
julia> r(rand(Float32, 3)) |> size
155+
(5,)
171156
172157
julia> r.state |> size
173158
(5, 1)
174159
175-
julia> r(rand(Float32, 3, 10)) # batch size of 10
176-
5×10 Matrix{Float32}:
177-
0.50832 0.409913 0.392907 0.838393 0.297105 0.432568 0.439304 0.677793 0.690217 0.78335
178-
-0.36385 -0.271328 -0.405521 -0.443976 -0.279546 -0.171614 -0.328029 -0.551147 -0.272327 -0.336688
179-
0.272917 -0.0155508 0.0995184 0.580889 0.0502855 0.0375683 0.163693 0.39545 0.294581 0.461731
180-
-0.353226 -0.924237 -0.816582 -0.694016 -0.530896 -0.783385 -0.584767 -0.854036 -0.832923 -0.730812
181-
0.418002 0.657771 0.673267 0.388967 0.483295 0.444058 0.490792 0.707697 0.435467 0.350789
160+
julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
161+
(5, 10)
182162
183163
julia> r.state |> size # state shape has changed
184164
(5, 10)
185165
186-
julia> r(rand(Float32, 3)) # outputs a length 5*10 = 50 vector.
187-
50-element Vector{Float32}:
188-
0.8532559
189-
-0.5693587
190-
0.49786803
191-
192-
-0.7722325
193-
0.46099305
166+
julia> r(rand(Float32, 3)) |> size # outputs a length 5*10 = 50 vector.
167+
(50,)
194168
```
195-
196169
"""
197170
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
198171
Recur(m::RNNCell) = Recur(m, m.state0)
@@ -258,68 +231,43 @@ This constructor is syntactic sugar for `Recur(LSTMCell(a...))`, and so LSTMs ar
258231
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
259232
for a good overview of the internals.
260233
261-
# Examples
262-
```jldoctest
263-
julia> l = LSTM(3, 5)
264-
Recur(
265-
LSTMCell(3, 5), # 190 parameters
266-
) # Total: 5 trainable arrays, 190 parameters,
267-
# plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB.
268-
269-
julia> l(rand(Float32, 3)) |> size
270-
(5,)
271-
272-
julia> Flux.reset!(l);
273-
274-
julia> l(rand(Float32, 3, 10)) |> size # batch size of 10
275-
(5, 10)
276-
277-
julia> Flux.reset!(l);
278-
279-
julia> l(rand(Float32, 3))
280-
5-element Vector{Float32}:
281-
-0.025144277
282-
0.03836835
283-
0.13517386
284-
-0.028824253
285-
-0.057356793
286-
287-
# A demonstration of not using `reset!` when the batch size changes.
288-
julia> l = LSTM(3, 5)
289-
Recur(
290-
LSTMCell(3, 5), # 190 parameters
291-
) # Total: 5 trainable arrays, 190 parameters,
292-
# plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB.
293-
294-
julia> size.(l.state)
295-
((5, 1), (5, 1))
296-
297-
julia> l(rand(Float32, 3))
298-
5-element Vector{Float32}:
299-
0.038496178
300-
0.047853474
301-
0.025309514
302-
0.0934924
303-
0.05440048
304-
305-
julia> l(rand(Float32, 3, 10)) # batch size of 10
306-
5×10 Matrix{Float32}:
307-
0.169775 -0.0268295 0.0985312 0.0335569 0.023051 0.146001 0.0494771 0.12347 0.148342 0.00534695
308-
0.0784295 0.130255 0.0326518 0.0495609 0.108738 0.10251 0.0519795 0.0673814 0.0804598 0.135432
309-
0.109187 -0.0267218 0.0772971 0.0200508 0.0108066 0.0921862 0.0346887 0.0831271 0.0978057 -0.00210143
310-
0.0827624 0.163729 0.10911 0.134769 0.120407 0.0757773 0.0894074 0.130243 0.0895137 0.133424
311-
0.060574 0.127245 0.0145216 0.0635873 0.108584 0.0954128 0.0529619 0.0665022 0.0689427 0.127494
312-
313-
julia> size.(l.state) # state shape has changed
314-
((5, 10), (5, 10))
315-
316-
julia> l(rand(Float32, 3)) # outputs a length 5*10 = 50 vector.
317-
50-element Vector{Float32}:
318-
0.07209678
319-
0.1450204
320-
321-
0.14622498
322-
0.15595339
234+
# Examples
235+
```jldoctest
236+
julia> l = LSTM(3, 5)
237+
Recur(
238+
LSTMCell(3, 5), # 190 parameters
239+
) # Total: 5 trainable arrays, 190 parameters,
240+
# plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB.
241+
242+
julia> l(rand(Float32, 3)) |> size
243+
(5,)
244+
245+
julia> Flux.reset!(l);
246+
247+
julia> l(rand(Float32, 3, 10)) |> size # batch size of 10
248+
(5, 10)
249+
250+
# A demonstration of not using `reset!` when the batch size changes.
251+
julia> l = LSTM(3, 5)
252+
Recur(
253+
LSTMCell(3, 5), # 190 parameters
254+
) # Total: 5 trainable arrays, 190 parameters,
255+
# plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB.
256+
257+
julia> l.state .|> size
258+
((5, 1), (5, 1))
259+
260+
julia> l(rand(Float32, 3)) |> size
261+
(5,)
262+
263+
julia> l(rand(Float32, 3, 10)) |> size # batch size of 10
264+
(5, 10)
265+
266+
julia> l.state .|> size # state shape has changed
267+
((5, 10), (5, 10))
268+
269+
julia> l(rand(Float32, 3)) |> size # outputs a length 5*10 = 50 vector.
270+
(50,)
323271
```
324272
"""
325273
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
@@ -407,16 +355,6 @@ julia> Flux.reset!(g);
407355
julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
408356
(5, 10)
409357
410-
julia> Flux.reset!(g);
411-
412-
julia> g(rand(Float32, 3))
413-
5-element Vector{Float32}:
414-
0.05426188
415-
-0.111508384
416-
0.04700454
417-
0.06919164
418-
0.089212984
419-
420358
# A demonstration of not using `reset!` when the batch size changes.
421359
julia> g = GRU(3, 5)
422360
Recur(
@@ -427,32 +365,17 @@ Recur(
427365
julia> g.state |> size
428366
(5, 1)
429367
430-
julia> g(rand(Float32, 3))
431-
5-element Vector{Float32}:
432-
-0.11918676
433-
-0.089210495
434-
0.027523153
435-
0.017113047
436-
0.061968707
437-
438-
julia> g(rand(Float32, 3, 10)) # batch size of 10
439-
5×10 Matrix{Float32}:
440-
-0.198102 -0.187499 -0.265959 -0.21598 -0.210867 -0.379202 -0.262658 -0.213773 -0.236976 -0.266929
441-
-0.138773 -0.137587 -0.208564 -0.155394 -0.142374 -0.289558 -0.200516 -0.154471 -0.165038 -0.198165
442-
0.040142 0.0716526 0.122938 0.0606727 0.00901341 0.0754129 0.107307 0.0551935 0.0366661 0.0648411
443-
0.0655876 0.0512702 -0.0813906 0.120083 0.0521291 0.175624 0.110025 0.0345626 0.189902 -0.00220774
444-
0.0756504 0.0913944 0.0982122 0.122272 0.0471702 0.228589 0.168877 0.0778906 0.145469 0.0832033
368+
julia> g(rand(Float32, 3)) |> size
369+
(5,)
370+
371+
julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
372+
(5, 10)
445373
446374
julia> g.state |> size # state shape has changed
447375
(5, 10)
448376
449-
julia> g(rand(Float32, 3)) # outputs a length 5*10 = 50 vector.
450-
50-element Vector{Float32}:
451-
-0.2639928
452-
-0.18772684
453-
454-
-0.022745812
455-
0.040191136
377+
julia> g(rand(Float32, 3)) |> size # outputs a length 5*10 = 50 vector.
378+
(50,)
456379
```
457380
"""
458381
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
@@ -529,16 +452,6 @@ julia> Flux.reset!(g);
529452
julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
530453
(5, 10)
531454
532-
julia> Flux.reset!(g);
533-
534-
julia> g(rand(Float32, 3))
535-
5-element Vector{Float32}:
536-
0.05637428
537-
0.0084088165
538-
-0.036565308
539-
0.013599886
540-
-0.0168455
541-
542455
# A demonstration of not using `reset!` when the batch size changes.
543456
julia> g = GRUv3(3, 5)
544457
Recur(
@@ -549,32 +462,17 @@ Recur(
549462
julia> g.state |> size
550463
(5, 1)
551464
552-
julia> g(rand(Float32, 3))
553-
5-element Vector{Float32}:
554-
0.07569726
555-
0.23686615
556-
-0.01647649
557-
0.100590095
558-
0.06330994
559-
560-
julia> g(rand(Float32, 3, 10)) # batch size of 10
561-
5×10 Matrix{Float32}:
562-
0.0187245 0.135969 0.0808607 0.138937 0.0153128 0.0386136 0.0498803 -0.0273552 0.116714 0.0584934
563-
0.207631 0.146397 0.226232 0.297546 0.28957 0.199815 0.239891 0.27778 0.132326 0.0325415
564-
0.083468 -0.00669185 -0.0562241 0.00725718 0.0319667 -0.021063 0.0682753 0.0109112 0.0188356 0.0826402
565-
0.0700071 0.120734 0.108757 0.14339 0.0850359 0.0706199 0.0915005 0.05131 0.105372 0.0507574
566-
0.0505043 -0.0408188 0.0170817 0.0190653 0.0936475 0.0406348 0.044181 0.139226 -0.0355197 -0.0434937
465+
julia> g(rand(Float32, 3)) |> size
466+
(5,)
467+
468+
julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
469+
(5, 10)
567470
568471
julia> g.state |> size # state shape has changed
569472
(5, 10)
570473
571-
julia> g(rand(Float32, 3)) # outputs a length 5*10 = 50 vector.
572-
50-element Vector{Float32}:
573-
0.08773954
574-
0.34562656
575-
576-
0.13768406
577-
-0.015648054
474+
julia> g(rand(Float32, 3)) |> size # outputs a length 5*10 = 50 vector.
475+
(50,)
578476
```
579477
"""
580478
GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))

0 commit comments

Comments
 (0)