
Commit 890f6f6

fix recurrence docs
1 parent 46b73a8 commit 890f6f6

1 file changed: +46 / -43 lines

docs/src/models/recurrence.md

Lines changed: 46 additions & 43 deletions
@@ -13,7 +13,7 @@ An aspect to recognize is that in such model, the recurrent cells `A` all refer
In the most basic RNN case, cell A could be defined by the following:

```julia
- Wxh = randn(Float32, 5, 2)
+ Wxh = randn(Float32, 5, 4)
Whh = randn(Float32, 5, 5)
b = randn(Float32, 5)

@@ -22,7 +22,7 @@ function rnn_cell(h, x)
  return h, h
end

- x = rand(Float32, 2) # dummy data
+ x = rand(Float32, 4) # dummy data
h = rand(Float32, 5) # initial hidden state

h, y = rnn_cell(h, x)
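As an illustrative aside (assuming the `rnn_cell`, `Wxh`, `Whh` and `b` defined above), the hidden state returned by each call can be threaded through a whole sequence by hand; the `run_sequence` helper and the 3-step dummy data are hypothetical:

```julia
# Sketch: manually carrying the hidden state across a short sequence.
function run_sequence(xs)
  h = zeros(Float32, 5)              # initial hidden state
  ys = Vector{Vector{Float32}}()
  for x in xs
    h, y = rnn_cell(h, x)            # each step consumes and returns the state
    push!(ys, y)
  end
  return ys
end

ys = run_sequence([rand(Float32, 4) for _ in 1:3])  # one output per time step
```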
@@ -37,9 +37,9 @@ There are various recurrent cells available in Flux, notably `RNNCell`, `LSTMCel
```julia
using Flux

- rnn = Flux.RNNCell(2, 5)
+ rnn = Flux.RNNCell(4, 5)

- x = rand(Float32, 2) # dummy data
+ x = rand(Float32, 4) # dummy data
h = rand(Float32, 5) # initial hidden state

h, y = rnn(h, x)
@@ -50,7 +50,7 @@ h, y = rnn(h, x)
For the most part, we don't want to manage hidden states ourselves, but to treat our models as being stateful. Flux provides the `Recur` wrapper to do this.

```julia
- x = rand(Float32, 2)
+ x = rand(Float32, 4)
h = rand(Float32, 5)

m = Flux.Recur(rnn, h)
@@ -60,113 +60,116 @@ y = m(x)

The `Recur` wrapper stores the state between runs in the `m.state` field.

- If we use the `RNN(2, 5)` constructor – as opposed to `RNNCell` – you'll see that it's simply a wrapped cell.
+ If we use the `RNN(4, 5)` constructor – as opposed to `RNNCell` – you'll see that it's simply a wrapped cell.

```julia
- julia> RNN(2, 5)
- Recur(RNNCell(2, 5, tanh))
+ julia> RNN(4, 5)
+ Recur(RNNCell(4, 5, tanh))
```
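Since both forms are `Recur` wrappers, the state mentioned above can be inspected directly through `m.state`; a quick sketch, assuming the `m = Flux.Recur(rnn, h)` defined earlier (the `true`/`false` results are indicative):

```julia
julia> m.state === h        # the wrapper starts out holding the initial hidden state
true

julia> m(rand(Float32, 4)); # taking one step replaces the stored state

julia> m.state === h
false
```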

Equivalent to the `RNN` stateful constructor, `LSTM` and `GRU` are also available.
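As an illustrative sketch (the `lstm` and `gru` names are ours, with the same dimensions as the running example):

```julia
lstm = LSTM(4, 5)   # a Recur-wrapped LSTMCell
gru  = GRU(4, 5)    # a Recur-wrapped GRUCell

# Both are called just like the stateful RNN above:
lstm(rand(Float32, 4))   # returns a length-5 output and updates the internal state
```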

Using these tools, we can now build the model shown in the above diagram with:

```julia
- m = Chain(RNN(2, 5), Dense(5, 1), x -> reshape(x, :))
+ m = Chain(RNN(4, 5), Dense(5, 2))
```
+ In this example, each output has two components.

## Working with sequences

Using the previously defined `m` recurrent model, we can now apply it to a single step from our sequence:

```julia
- x = rand(Float32, 2)
+ julia> x = rand(Float32, 4);
+
julia> m(x)
- 1-element Array{Float32,1}:
- 0.028398542
+ 2-element Vector{Float32}:
+ -0.12852919
+ 0.009802654
```

The `m(x)` operation would be represented by `x1 -> A -> y1` in our diagram.
- If we perform this operation a second time, it will be equivalent to `x2 -> A -> y2` since the model `m` has stored the state resulting from the `x1` step:
-
- ```julia
- x = rand(Float32, 2)
- julia> m(x)
- 1-element Array{Float32,1}:
- 0.07381232
- ```
+ If we perform this operation a second time, it will be equivalent to `x2 -> A -> y2`
+ since the model `m` has stored the state resulting from the `x1` step.

- Now, instead of computing a single step at a time, we can get the full `y1` to `y3` sequence in a single pass by broadcasting the model on a sequence of data.
+ Now, instead of computing a single step at a time, we can get the full `y1` to `y3` sequence in a single pass by
+ iterating the model on a sequence of data.

To do so, we'll need to structure the input data as a `Vector` of observations at each time step. This `Vector` will therefore be of `length = seq_length` and each of its elements will represent the input features for a given step. In our example, this translates into a `Vector` of length 3, where each element is a `Matrix` of size `(features, batch_size)`, or just a `Vector` of length `features` if dealing with a single observation.

```julia
- x = [rand(Float32, 2) for i = 1:3]
- julia> m.(x)
- 3-element Array{Array{Float32,1},1}:
- [-0.17945863]
- [-0.20863166]
- [-0.20693761]
+ julia> x = [rand(Float32, 4) for i = 1:3];
+
+ julia> [m(x[i]) for i = 1:3]
+ 3-element Vector{Vector{Float32}}:
+ [-0.018976994, 0.61098206]
+ [-0.8924057, -0.7512169]
+ [-0.34613007, -0.54565114]
```

If for some reason one wants to exclude the first step of the RNN chain for the computation of the loss, that can be handled with:

```julia
function loss(x, y)
-   sum((Flux.stack(m.(x)[2:end],1) .- y) .^ 2)
+   m(x[1]) # ignores the output but updates the hidden states
+   l = 0f0
+   for i in 2:length(x)
+     l += sum((m(x[i]) .- y[i-1]).^2)
+   end
+   return l
end

- y = rand(Float32, 2)
- julia> loss(x, y)
- 1.7021208968648693
+ y = [rand(Float32, 2) for i=1:2]
+ loss(x, y)
```

- In such model, only `y2` and `y3` are used to compute the loss, hence the target `y` being of length 2. This is a strategy that can be used to easily handle a `seq-to-one` kind of structure, compared to the `seq-to-seq` assumed so far.
+ In such a model, only the last two outputs are used to compute the loss, hence the target `y` being of length 2. This is a strategy that can be used to easily handle a `seq-to-one` kind of structure, compared to the `seq-to-seq` assumed so far.
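As a sketch of that `seq-to-one` idea (the `loss_last` name and the single final-step target are hypothetical), one could keep only the last output:

```julia
function loss_last(x, y)
  out = [m(xi) for xi in x]          # run the whole sequence
  sum((out[end] .- y) .^ 2)          # only the final step contributes to the loss
end

y_last = rand(Float32, 2)            # a single target for the final step
loss_last(x, y_last)
```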

Alternatively, if one wants to perform some warmup of the sequence, it could be performed once, followed with a regular training where all the steps of the sequence would be considered for the gradient update:

```julia
function loss(x, y)
-   sum((Flux.stack(m.(x),1) .- y) .^ 2)
+   sum(sum((m(x[i]) .- y[i]).^2) for i=1:length(x))
end

- seq_init = [rand(Float32, 2) for i = 1:1]
- seq_1 = [rand(Float32, 2) for i = 1:3]
- seq_2 = [rand(Float32, 2) for i = 1:3]
+ seq_init = [rand(Float32, 4) for i = 1:1]
+ seq_1 = [rand(Float32, 4) for i = 1:3]
+ seq_2 = [rand(Float32, 4) for i = 1:3]

- y1 = rand(Float32, 3)
- y2 = rand(Float32, 3)
+ y1 = [rand(Float32, 2) for i = 1:3]
+ y2 = [rand(Float32, 2) for i = 1:3]

X = [seq_1, seq_2]
Y = [y1, y2]
data = zip(X,Y)

Flux.reset!(m)
- m.(seq_init)
+ [m(x) for x in seq_init]

ps = params(m)
opt= ADAM(1e-3)
Flux.train!(loss, ps, data, opt)
```

- In this previous example, model's state is first reset with `Flux.reset!`. Then, there's a warmup that is performed over a sequence of length 1 by feeding it with `seq_init`, resulting in a warmup state. The model can then be trained for 1 epoch, where 2 batches are provided (`seq_1` and `seq_2`) and all the timesteps outputs are considered for the loss (we no longer use a subset of `m.(x)` in the loss function).
+ In this previous example, the model's state is first reset with `Flux.reset!`. Then, a warmup is performed over a sequence of length 1 by feeding it `seq_init`, resulting in a warmup state. The model can then be trained for 1 epoch, where 2 batches are provided (`seq_1` and `seq_2`) and all the timestep outputs are considered for the loss.

In this scenario, it is important to note that a single continuous sequence is considered. Since the model state is not reset between the 2 batches, the state of the model flows through the batches, which only makes sense in the context where `seq_1` is the continuation of `seq_init` and so on.

Batch size would be 1 here as there's only a single sequence within each batch. If the model was to be trained on multiple independent sequences, then these sequences could be added to the input data as a second dimension. For example, in a language model, each batch would contain multiple independent sentences. In such scenario, if we set the batch size to 4, a single batch would be of the shape:

```julia
- batch = [rand(Float32, 2, 4) for i = 1:3]
+ batch = [rand(Float32, 4, 4) for i = 1:3]
```

- That would mean that we have 4 sentences (or samples), each with 2 features (let's say a very small embedding!) and each with a length of 3 (3 words per sentence). Computing `m(batch[1])`, would still represent `x1 -> y1` in our diagram and returns the first word output, but now for each of the 4 independent sentences (second dimension of the input matrix).
+ That would mean that we have 4 sentences (or samples), each with 4 features (let's say a very small embedding!) and each with a length of 3 (3 words per sentence). Computing `m(batch[1])` would still represent `x1 -> y1` in our diagram and return the first word output, but now for each of the 4 independent sentences (second dimension of the input matrix).
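To make the batched step concrete, a small sketch (output values depend on the random weights; `Flux.reset!` clears any state left over from the single-observation examples):

```julia
Flux.reset!(m)
out = m(batch[1])   # output for the first word of each of the 4 sentences
size(out)           # (2, 4): 2 output components per sentence
```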

In many situations, such as when dealing with a language model, each batch typically contains independent sentences, so we cannot handle the model as if each batch was the direct continuation of the previous one. To handle such situation, we need to reset the state of the model between each batch, which can be conveniently performed within the loss function:

```julia
function loss(x, y)
  Flux.reset!(m)
-   sum((Flux.stack(m.(x),1) .- y) .^ 2)
+   sum(sum((m(x[i]) .- y[i]).^2) for i=1:length(x))
end
```

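A hypothetical end-to-end use of this per-batch-reset loss with independent batched sequences (all names and sizes below are illustrative, matching the 4-feature, batch-of-4 example above):

```julia
# Two independent batches, each a 3-step sequence of (features=4, batch_size=4),
# with 2-component targets for every step.
X = [[rand(Float32, 4, 4) for t = 1:3] for _ = 1:2]
Y = [[rand(Float32, 2, 4) for t = 1:3] for _ = 1:2]

Flux.train!(loss, params(m), zip(X, Y), ADAM(1e-3))
```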