Skip to content

Commit 66dddb9

Browse files
authored
Merge pull request #2075 from mcognetta/rnn_cell_docs
Finish docs for #2073
2 parents 9a8a676 + b26113d commit 66dddb9

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

src/layers/recurrent.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,8 @@ julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
272272
(50,)
273273
```
274274
275-
# Note:
276-
`RNNCell`s can be constructed directly by specifying the non-linear function, the `W_i` and `W_h` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The `W_i` and `W_h` matrices do not need to be the same type, but if `W_h` is `dxd`, then `W_i` should be of shape `dxN`.
275+
# Note:
276+
`RNNCell`s can be constructed directly by specifying the non-linear function, the `Wi` and `Wh` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The `Wi` and `Wh` matrices do not need to be the same type, but if `Wh` is `dxd`, then `Wi` should be of shape `dxN`.
277277
278278
```julia
279279
julia> using LinearAlgebra
@@ -282,7 +282,7 @@ julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
282282
283283
julia> r(rand(4, 10)) |> size # batch size of 10
284284
(5, 10)
285-
````
285+
```
286286
"""
287287
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
288288
Recur(m::RNNCell) = Recur(m, m.state0)
@@ -351,6 +351,9 @@ julia> l(rand(Float32, 3, 10)) |> size # batch size of 10
351351
352352
!!! warning "Batch size changes"
353353
Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).
354+
355+
# Note:
356+
`LSTMCell`s can be constructed directly by specifying the non-linear function, the `Wi` and `Wh` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The `Wi` and `Wh` matrices do not need to be the same type. See the example in [`RNN`](@ref).
354357
"""
355358
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
356359
Recur(m::LSTMCell) = Recur(m, m.state0)
@@ -420,6 +423,9 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
420423
421424
!!! warning "Batch size changes"
422425
Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).
426+
427+
# Note:
428+
`GRUCell`s can be constructed directly by specifying the non-linear function, the `Wi` and `Wh` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The `Wi` and `Wh` matrices do not need to be the same type. See the example in [`RNN`](@ref).
423429
"""
424430
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
425431
Recur(m::GRUCell) = Recur(m, m.state0)
@@ -485,6 +491,9 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
485491
486492
!!! warning "Batch size changes"
487493
Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref).
494+
495+
# Note:
496+
`GRUv3Cell`s can be constructed directly by specifying the non-linear function, the `Wi`, `Wh`, and `Wh_h` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The `Wi`, `Wh`, and `Wh_h` matrices do not need to be the same type. See the example in [`RNN`](@ref).
488497
"""
489498
GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
490499
Recur(m::GRUv3Cell) = Recur(m, m.state0)

test/layers/recurrent.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,4 +168,4 @@ end
168168
Flux.reset!(m)
169169
@test size(m(x3)) == (5, 1, 2)
170170
end
171-
end
171+
end

0 commit comments

Comments
 (0)