@@ -251,30 +251,7 @@ julia> l(rand(Float32, 3, 10)) |> size # batch size of 10
 (5, 10)
 ```
 
-The following is a demonstration of when failing to call `reset!` between batch size changes can cause erroneous outputs.
-
-```julia
-julia> l = LSTM(3, 5)
-Recur(
-  LSTMCell(3, 5), # 190 parameters
-) # Total: 5 trainable arrays, 190 parameters,
-  # plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB.
-
-julia> l.state .|> size
-((5, 1), (5, 1))
-
-julia> l(rand(Float32, 3)) |> size
-(5,)
-
-julia> l(rand(Float32, 3, 10)) |> size # batch size of 10
-(5, 10)
-
-julia> l.state .|> size # state shape has changed
-((5, 10), (5, 10))
-
-julia> l(rand(Float32, 3)) |> size # erroneously outputs a length 5*10 = 50 vector.
-(50,)
-```
+!!! warning Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).
 """
 LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
 Recur(m::LSTMCell) = Recur(m, m.state0)
@@ -362,30 +339,7 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
 (5, 10)
 ```
 
-The following is a demonstration of when failing to call `reset!` between batch size changes can cause erroneous outputs.
-
-```julia
-julia> g = GRU(3, 5)
-Recur(
-  GRUCell(3, 5), # 140 parameters
-) # Total: 4 trainable arrays, 140 parameters,
-  # plus 1 non-trainable, 5 parameters, summarysize 792 bytes.
-
-julia> g.state |> size
-(5, 1)
-
-julia> g(rand(Float32, 3)) |> size
-(5,)
-
-julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
-(5, 10)
-
-julia> g.state |> size # state shape has changed
-(5, 10)
-
-julia> g(rand(Float32, 3)) |> size # erroneously outputs a length 5*10 = 50 vector.
-(50,)
-```
+!!! warning Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).
 """
 GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
 Recur(m::GRUCell) = Recur(m, m.state0)
@@ -462,30 +416,7 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
 (5, 10)
 ```
 
-The following is a demonstration of when failing to call `reset!` between batch size changes can cause erroneous outputs.
-
-```julia
-julia> g = GRUv3(3, 5)
-Recur(
-  GRUv3Cell(3, 5), # 140 parameters
-) # Total: 5 trainable arrays, 140 parameters,
-  # plus 1 non-trainable, 5 parameters, summarysize 848 bytes.
-
-julia> g.state |> size
-(5, 1)
-
-julia> g(rand(Float32, 3)) |> size
-(5,)
-
-julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
-(5, 10)
-
-julia> g.state |> size # state shape has changed
-(5, 10)
-
-julia> g(rand(Float32, 3)) |> size # erroneously outputs a length 5*10 = 50 vector.
-(50,)
-```
+!!! warning Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).
 """
 GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
 Recur(m::GRUv3Cell) = Recur(m, m.state0)
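
The warning introduced in these hunks tells readers to call `reset!` when the input batch size changes but no longer shows the call itself. A minimal sketch of the intended usage follows. It assumes a Flux release that still ships the stateful `Recur` wrapper and `Flux.reset!` (the API this diff edits); the variable names are illustrative only.

```julia
using Flux  # assumption: a Flux version that still provides `Recur` and `Flux.reset!`

l = LSTM(3, 5)

# A batch of 10 samples: the hidden state now carries a batch dimension of 10.
out_batch = l(rand(Float32, 3, 10))

# Restore the initial state before switching to a different batch size.
Flux.reset!(l)

# A single sample now runs against the freshly reset state.
out_single = l(rand(Float32, 3))
```

Without the `Flux.reset!` call, the second input would be combined with the state left over from the batch of 10, which is the failure mode the removed docstring example demonstrated.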