Commit cea8f75

update docs
1 parent ab24e8e commit cea8f75

1 file changed (+23, -24 lines)


docs/src/models/recurrence.md

Lines changed: 23 additions & 24 deletions
@@ -13,7 +13,7 @@ An aspect to recognize is that in such model, the recurrent cells `A` all refer
 In the most basic RNN case, cell A could be defined by the following:
 
 ```julia
-Wxh = randn(Float32, 5, 4)
+Wxh = randn(Float32, 5, 2)
 Whh = randn(Float32, 5, 5)
 b = randn(Float32, 5)

@@ -22,7 +22,7 @@ function rnn_cell(h, x)
     return h, h
 end
 
-x = rand(Float32, 4) # dummy data
+x = rand(Float32, 2) # dummy data
 h = rand(Float32, 5) # initial hidden state
 
 h, y = rnn_cell(h, x)
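
To make the cell's contract concrete, here is a minimal sketch (assuming the `rnn_cell`, `Wxh`, `Whh`, and `b` definitions above) of unrolling it over a short sequence by threading the hidden state through each step:

```julia
# Unroll the hand-written cell over a 3-step sequence.
# The hidden state h returned by one step feeds the next.
xs = [rand(Float32, 2) for _ in 1:3]   # 3 time steps, 2 features each
h  = zeros(Float32, 5)                 # initial hidden state
ys = Vector{Vector{Float32}}()
for x in xs
    h, y = rnn_cell(h, x)              # for this cell, the new state doubles as the output
    push!(ys, y)
end
```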
@@ -37,9 +37,9 @@ There are various recurrent cells available in Flux, notably `RNNCell`, `LSTMCel
 ```julia
 using Flux
 
-rnn = Flux.RNNCell(4, 5)
+rnn = Flux.RNNCell(2, 5)
 
-x = rand(Float32, 4) # dummy data
+x = rand(Float32, 2) # dummy data
 h = rand(Float32, 5) # initial hidden state
 
 h, y = rnn(h, x)
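
The other cells follow the same calling convention. As a sketch (assuming the same Flux version this patch targets), `LSTMCell` simply carries a pair of states `(h, c)` instead of a single vector:

```julia
using Flux

lstm = Flux.LSTMCell(2, 5)                      # in = 2 features, out = 5 hidden units
x = rand(Float32, 2)                            # dummy data
state = (zeros(Float32, 5), zeros(Float32, 5))  # (hidden state, cell state)
state, y = lstm(state, x)                       # returns the new state pair and the output
```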
@@ -50,7 +50,7 @@ h, y = rnn(h, x)
 For the most part, we don't want to manage hidden states ourselves, but to treat our models as being stateful. Flux provides the `Recur` wrapper to do this.
 
 ```julia
-x = rand(Float32, 4)
+x = rand(Float32, 2)
 h = rand(Float32, 5)
 
 m = Flux.Recur(rnn, h)
@@ -60,19 +60,19 @@ y = m(x)
 
 The `Recur` wrapper stores the state between runs in the `m.state` field.
 
-If we use the `RNN(4, 5)` constructor – as opposed to `RNNCell` – you'll see that it's simply a wrapped cell.
+If we use the `RNN(2, 5)` constructor – as opposed to `RNNCell` – you'll see that it's simply a wrapped cell.
 
 ```julia
-julia> RNN(4, 5)
-Recur(RNNCell(4, 5, tanh))
+julia> RNN(2, 5)
+Recur(RNNCell(2, 5, tanh))
 ```
 
 Equivalent to the `RNN` stateful constructor, `LSTM` and `GRU` are also available.
 
 Using these tools, we can now build the model shown in the above diagram with:
 
 ```julia
-m = Chain(RNN(4, 5), Dense(5, 2))
+m = Chain(RNN(2, 5), Dense(5, 2))
 ```
 In this example, each output has two components.

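Since the state lives in `m.state`, it can be inspected between calls, and `Flux.reset!` restores the initial state. A small illustrative sketch, assuming the constructors above:

```julia
m = RNN(2, 5)
x = rand(Float32, 2)

m(x)            # step once; updates m.state as a side effect
m.state         # the hidden state after that step, a length-5 vector
Flux.reset!(m)  # restore the initial state stored in the cell
```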
@@ -81,7 +81,7 @@ In this example, each output has two components.
 Using the previously defined `m` recurrent model, we can now apply it to a single step from our sequence:
 
 ```julia
-julia> x = rand(Float32, 4);
+julia> x = rand(Float32, 2);
 
 julia> m(x)
 2-element Vector{Float32}:
@@ -99,17 +99,17 @@ iterating the model on a sequence of data.
 To do so, we'll need to structure the input data as a `Vector` of observations at each time step. This `Vector` will therefore be of `length = seq_length` and each of its elements will represent the input features for a given step. In our example, this translates into a `Vector` of length 3, where each element is a `Matrix` of size `(features, batch_size)`, or just a `Vector` of length `features` if dealing with a single observation.
 
 ```julia
-julia> x = [rand(Float32, 4) for i = 1:3];
+julia> x = [rand(Float32, 2) for i = 1:3];
 
-julia> [m(x[i]) for i = 1:3]
+julia> [m(xi) for xi in x]
 3-element Vector{Vector{Float32}}:
 [-0.018976994, 0.61098206]
 [-0.8924057, -0.7512169]
 [-0.34613007, -0.54565114]
 ```
 
 !!! warning "Use of map and broadcast"
-    Mapping and broadcasting operations with stateful layers such as the one we are considering are discouraged,
+    Mapping and broadcasting operations with stateful layers are discouraged,
     since the Julia language doesn't guarantee a specific execution order.
     Therefore, avoid
 ```julia
@@ -125,12 +125,11 @@ julia> [m(x[i]) for i = 1:3]
 If for some reason one wants to exclude the first step of the RNN chain for the computation of the loss, that can be handled with:
 
 ```julia
+using Flux.Losses: mse
+
 function loss(x, y)
   m(x[1]) # ignores the output but updates the hidden states
-  l = 0f0
-  for i in 2:length(x)
-    l += sum((m(x[i]) .- y[i-1]).^2)
-  end
+  l = sum(mse(m(xi), yi) for (xi, yi) in zip(x[2:end], y))
   return l
 end

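A quick usage sketch for this loss (hypothetical data, with shapes matching `m = Chain(RNN(2, 5), Dense(5, 2))` from above): since the first model call only warms up the state, the target sequence is one step shorter than the input:

```julia
x = [rand(Float32, 2) for i = 1:3]   # 3-step input sequence
y = [rand(Float32, 2) for i = 1:2]   # targets for steps 2 and 3 only
Flux.reset!(m)
loss(x, y)                           # scalar loss over the last two steps
```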
@@ -144,12 +143,12 @@ Alternatively, if one wants to perform some warmup of the sequence, it could be
 
 ```julia
 function loss(x, y)
-  sum(sum((m(x[i]) .- y[i]).^2) for i=1:length(x))
+  sum(mse(m(xi), yi) for (xi, yi) in zip(x, y))
 end
 
-seq_init = [rand(Float32, 4) for i = 1:1]
-seq_1 = [rand(Float32, 4) for i = 1:3]
-seq_2 = [rand(Float32, 4) for i = 1:3]
+seq_init = [rand(Float32, 2)]
+seq_1 = [rand(Float32, 2) for i = 1:3]
+seq_2 = [rand(Float32, 2) for i = 1:3]
 
 y1 = [rand(Float32, 2) for i = 1:3]
 y2 = [rand(Float32, 2) for i = 1:3]
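
One way this warmup could be driven (a hedged sketch using the standard Flux training loop of this era; the optimiser and learning rate are illustrative): feed `seq_init` through the model purely for its state side effect, then train on the labelled chunks:

```julia
Flux.reset!(m)
[m(x) for x in seq_init]   # warmup: update the hidden state, discard outputs

ps = Flux.params(m)
opt = ADAM(1e-3)
Flux.train!(loss, ps, zip([seq_1, seq_2], [y1, y2]), opt)
```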
@@ -173,17 +172,17 @@ In this scenario, it is important to note that a single continuous sequence is c
 Batch size would be 1 here as there's only a single sequence within each batch. If the model were to be trained on multiple independent sequences, then these sequences could be added to the input data as a second dimension. For example, in a language model, each batch would contain multiple independent sentences. In such a scenario, if we set the batch size to 4, a single batch would be of the shape:
 
 ```julia
-batch = [rand(Float32, 4, 4) for i = 1:3]
+batch = [rand(Float32, 2, 4) for i = 1:3]
 ```
 
-That would mean that we have 4 sentences (or samples), each with 4 features (let's say a very small embedding!) and each with a length of 3 (3 words per sentence). Computing `m(batch[1])`, would still represent `x1 -> y1` in our diagram and returns the first word output, but now for each of the 4 independent sentences (second dimension of the input matrix).
+That would mean that we have 4 sentences (or samples), each with 2 features (let's say a very small embedding!) and each with a length of 3 (3 words per sentence). Computing `m(batch[1])` would still represent `x1 -> y1` in our diagram and return the first word output, but now for each of the 4 independent sentences (second dimension of the input matrix).
 
 In many situations, such as when dealing with a language model, each batch typically contains independent sentences, so we cannot handle the model as if each batch was the direct continuation of the previous one. To handle such a situation, we need to reset the state of the model between each batch, which can be conveniently performed within the loss function:
 
 ```julia
 function loss(x, y)
   Flux.reset!(m)
-  sum(sum((m(x[i]) .- y[i]).^2) for i=1:length(x))
+  sum(mse(m(xi), yi) for (xi, yi) in zip(x, y))
 end
 ```

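To see the reset-per-batch behaviour end to end, a final hypothetical sketch (the names `batch_x` and `batch_y` are illustrative): each batch holds 3 steps of `(2, 4)` matrices, and the loss wipes the state before consuming them:

```julia
batch_x = [rand(Float32, 2, 4) for i = 1:3]  # 3 steps, 2 features, batch of 4
batch_y = [rand(Float32, 2, 4) for i = 1:3]  # matching targets per step

loss(batch_x, batch_y)  # state is reset inside the loss, so batches stay independent
```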