@@ -138,18 +138,8 @@ julia> r(rand(Float32, 3)) |> size
138
138
139
139
julia> Flux.reset!(r);
140
140
141
- julia> r(rand(Float32, 3, 64)) |> size
142
- (5, 64)
143
-
144
- julia> Flux.reset!(r);
145
-
146
- julia> r(rand(Float32, 3))
147
- 5-element Vector{Float32}:
148
- -0.37216917
149
- -0.14777198
150
- 0.2281275
151
- 0.32866752
152
- -0.6388411
141
+ julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
142
+ (5, 10)
153
143
154
144
# A demonstration of not using `reset!` when the batch size changes.
155
145
julia> r = RNN(3, 5)
@@ -161,38 +151,21 @@ Recur(
161
151
julia> r.state |> size
162
152
(5, 1)
163
153
164
- julia> r(rand(Float32, 3))
165
- 5-element Vector{Float32}:
166
- 0.3572996
167
- -0.041238427
168
- 0.19673917
169
- -0.36114445
170
- -0.0023919558
154
+ julia> r(rand(Float32, 3)) |> size
155
+ (5,)
171
156
172
157
julia> r.state |> size
173
158
(5, 1)
174
159
175
- julia> r(rand(Float32, 3, 10)) # batch size of 10
176
- 5×10 Matrix{Float32}:
177
- 0.50832 0.409913 0.392907 0.838393 0.297105 0.432568 0.439304 0.677793 0.690217 0.78335
178
- -0.36385 -0.271328 -0.405521 -0.443976 -0.279546 -0.171614 -0.328029 -0.551147 -0.272327 -0.336688
179
- 0.272917 -0.0155508 0.0995184 0.580889 0.0502855 0.0375683 0.163693 0.39545 0.294581 0.461731
180
- -0.353226 -0.924237 -0.816582 -0.694016 -0.530896 -0.783385 -0.584767 -0.854036 -0.832923 -0.730812
181
- 0.418002 0.657771 0.673267 0.388967 0.483295 0.444058 0.490792 0.707697 0.435467 0.350789
160
+ julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
161
+ (5, 10)
182
162
183
163
julia> r.state |> size # state shape has changed
184
164
(5, 10)
185
165
186
- julia> r(rand(Float32, 3)) # outputs a length 5*10 = 50 vector.
187
- 50-element Vector{Float32}:
188
- 0.8532559
189
- -0.5693587
190
- 0.49786803
191
- ⋮
192
- -0.7722325
193
- 0.46099305
166
+ julia> r(rand(Float32, 3)) |> size # outputs a length 5*10 = 50 vector.
167
+ (50,)
194
168
```
195
-
196
169
"""
197
170
RNN (a... ; ka... ) = Recur (RNNCell (a... ; ka... ))
198
171
Recur (m:: RNNCell ) = Recur (m, m. state0)
@@ -258,68 +231,43 @@ This constructor is syntactic sugar for `Recur(LSTMCell(a...))`, and so LSTMs ar
258
231
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
259
232
for a good overview of the internals.
260
233
261
- # Examples
262
- ```jldoctest
263
- julia> l = LSTM(3, 5)
264
- Recur(
265
- LSTMCell(3, 5), # 190 parameters
266
- ) # Total: 5 trainable arrays, 190 parameters,
267
- # plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB.
268
-
269
- julia> l(rand(Float32, 3)) |> size
270
- (5,)
271
-
272
- julia> Flux.reset!(l);
273
-
274
- julia> l(rand(Float32, 3, 10)) |> size # batch size of 10
275
- (5, 10)
276
-
277
- julia> Flux.reset!(l);
278
-
279
- julia> l(rand(Float32, 3))
280
- 5-element Vector{Float32}:
281
- -0.025144277
282
- 0.03836835
283
- 0.13517386
284
- -0.028824253
285
- -0.057356793
286
-
287
- # A demonstration of not using `reset!` when the batch size changes.
288
- julia> l = LSTM(3, 5)
289
- Recur(
290
- LSTMCell(3, 5), # 190 parameters
291
- ) # Total: 5 trainable arrays, 190 parameters,
292
- # plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB.
293
-
294
- julia> size.(l.state)
295
- ((5, 1), (5, 1))
296
-
297
- julia> l(rand(Float32, 3))
298
- 5-element Vector{Float32}:
299
- 0.038496178
300
- 0.047853474
301
- 0.025309514
302
- 0.0934924
303
- 0.05440048
304
-
305
- julia> l(rand(Float32, 3, 10)) # batch size of 10
306
- 5×10 Matrix{Float32}:
307
- 0.169775 -0.0268295 0.0985312 0.0335569 0.023051 0.146001 0.0494771 0.12347 0.148342 0.00534695
308
- 0.0784295 0.130255 0.0326518 0.0495609 0.108738 0.10251 0.0519795 0.0673814 0.0804598 0.135432
309
- 0.109187 -0.0267218 0.0772971 0.0200508 0.0108066 0.0921862 0.0346887 0.0831271 0.0978057 -0.00210143
310
- 0.0827624 0.163729 0.10911 0.134769 0.120407 0.0757773 0.0894074 0.130243 0.0895137 0.133424
311
- 0.060574 0.127245 0.0145216 0.0635873 0.108584 0.0954128 0.0529619 0.0665022 0.0689427 0.127494
312
-
313
- julia> size.(l.state) # state shape has changed
314
- ((5, 10), (5, 10))
315
-
316
- julia> l(rand(Float32, 3)) # outputs a length 5*10 = 50 vector.
317
- 50-element Vector{Float32}:
318
- 0.07209678
319
- 0.1450204
320
- ⋮
321
- 0.14622498
322
- 0.15595339
234
+ # Examples
235
+ ```jldoctest
236
+ julia> l = LSTM(3, 5)
237
+ Recur(
238
+ LSTMCell(3, 5), # 190 parameters
239
+ ) # Total: 5 trainable arrays, 190 parameters,
240
+ # plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB.
241
+
242
+ julia> l(rand(Float32, 3)) |> size
243
+ (5,)
244
+
245
+ julia> Flux.reset!(l);
246
+
247
+ julia> l(rand(Float32, 3, 10)) |> size # batch size of 10
248
+ (5, 10)
249
+
250
+ # A demonstration of not using `reset!` when the batch size changes.
251
+ julia> l = LSTM(3, 5)
252
+ Recur(
253
+ LSTMCell(3, 5), # 190 parameters
254
+ ) # Total: 5 trainable arrays, 190 parameters,
255
+ # plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB.
256
+
257
+ julia> l.state .|> size
258
+ ((5, 1), (5, 1))
259
+
260
+ julia> l(rand(Float32, 3)) |> size
261
+ (5,)
262
+
263
+ julia> l(rand(Float32, 3, 10)) |> size # batch size of 10
264
+ (5, 10)
265
+
266
+ julia> l.state .|> size # state shape has changed
267
+ ((5, 10), (5, 10))
268
+
269
+ julia> l(rand(Float32, 3)) |> size # outputs a length 5*10 = 50 vector.
270
+ (50,)
323
271
```
324
272
"""
325
273
LSTM (a... ; ka... ) = Recur (LSTMCell (a... ; ka... ))
@@ -407,16 +355,6 @@ julia> Flux.reset!(g);
407
355
julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
408
356
(5, 10)
409
357
410
- julia> Flux.reset!(g);
411
-
412
- julia> g(rand(Float32, 3))
413
- 5-element Vector{Float32}:
414
- 0.05426188
415
- -0.111508384
416
- 0.04700454
417
- 0.06919164
418
- 0.089212984
419
-
420
358
# A demonstration of not using `reset!` when the batch size changes.
421
359
julia> g = GRU(3, 5)
422
360
Recur(
@@ -427,32 +365,17 @@ Recur(
427
365
julia> g.state |> size
428
366
(5, 1)
429
367
430
- julia> g(rand(Float32, 3))
431
- 5-element Vector{Float32}:
432
- -0.11918676
433
- -0.089210495
434
- 0.027523153
435
- 0.017113047
436
- 0.061968707
437
-
438
- julia> g(rand(Float32, 3, 10)) # batch size of 10
439
- 5×10 Matrix{Float32}:
440
- -0.198102 -0.187499 -0.265959 -0.21598 -0.210867 -0.379202 -0.262658 -0.213773 -0.236976 -0.266929
441
- -0.138773 -0.137587 -0.208564 -0.155394 -0.142374 -0.289558 -0.200516 -0.154471 -0.165038 -0.198165
442
- 0.040142 0.0716526 0.122938 0.0606727 0.00901341 0.0754129 0.107307 0.0551935 0.0366661 0.0648411
443
- 0.0655876 0.0512702 -0.0813906 0.120083 0.0521291 0.175624 0.110025 0.0345626 0.189902 -0.00220774
444
- 0.0756504 0.0913944 0.0982122 0.122272 0.0471702 0.228589 0.168877 0.0778906 0.145469 0.0832033
368
+ julia> g(rand(Float32, 3)) |> size
369
+ (5,)
370
+
371
+ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
372
+ (5, 10)
445
373
446
374
julia> g.state |> size # state shape has changed
447
375
(5, 10)
448
376
449
- julia> g(rand(Float32, 3)) # outputs a length 5*10 = 50 vector.
450
- 50-element Vector{Float32}:
451
- -0.2639928
452
- -0.18772684
453
- ⋮
454
- -0.022745812
455
- 0.040191136
377
+ julia> g(rand(Float32, 3)) |> size # outputs a length 5*10 = 50 vector.
378
+ (50,)
456
379
```
457
380
"""
458
381
GRU (a... ; ka... ) = Recur (GRUCell (a... ; ka... ))
@@ -529,16 +452,6 @@ julia> Flux.reset!(g);
529
452
julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
530
453
(5, 10)
531
454
532
- julia> Flux.reset!(g);
533
-
534
- julia> g(rand(Float32, 3))
535
- 5-element Vector{Float32}:
536
- 0.05637428
537
- 0.0084088165
538
- -0.036565308
539
- 0.013599886
540
- -0.0168455
541
-
542
455
# A demonstration of not using `reset!` when the batch size changes.
543
456
julia> g = GRUv3(3, 5)
544
457
Recur(
@@ -549,32 +462,17 @@ Recur(
549
462
julia> g.state |> size
550
463
(5, 1)
551
464
552
- julia> g(rand(Float32, 3))
553
- 5-element Vector{Float32}:
554
- 0.07569726
555
- 0.23686615
556
- -0.01647649
557
- 0.100590095
558
- 0.06330994
559
-
560
- julia> g(rand(Float32, 3, 10)) # batch size of 10
561
- 5×10 Matrix{Float32}:
562
- 0.0187245 0.135969 0.0808607 0.138937 0.0153128 0.0386136 0.0498803 -0.0273552 0.116714 0.0584934
563
- 0.207631 0.146397 0.226232 0.297546 0.28957 0.199815 0.239891 0.27778 0.132326 0.0325415
564
- 0.083468 -0.00669185 -0.0562241 0.00725718 0.0319667 -0.021063 0.0682753 0.0109112 0.0188356 0.0826402
565
- 0.0700071 0.120734 0.108757 0.14339 0.0850359 0.0706199 0.0915005 0.05131 0.105372 0.0507574
566
- 0.0505043 -0.0408188 0.0170817 0.0190653 0.0936475 0.0406348 0.044181 0.139226 -0.0355197 -0.0434937
465
+ julia> g(rand(Float32, 3)) |> size
466
+ (5,)
467
+
468
+ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
469
+ (5, 10)
567
470
568
471
julia> g.state |> size # state shape has changed
569
472
(5, 10)
570
473
571
- julia> g(rand(Float32, 3)) # outputs a length 5*10 = 50 vector.
572
- 50-element Vector{Float32}:
573
- 0.08773954
574
- 0.34562656
575
- ⋮
576
- 0.13768406
577
- -0.015648054
474
+ julia> g(rand(Float32, 3)) |> size # outputs a length 5*10 = 50 vector.
475
+ (50,)
578
476
```
579
477
"""
580
478
GRUv3 (a... ; ka... ) = Recur (GRUv3Cell (a... ; ka... ))
0 commit comments