docs/src/models/recurrence.md

An aspect to recognize is that in such a model, the recurrent cells `A` all refer to the same structure.

In the most basic RNN case, cell A could be defined by the following:

```julia
Wxh = randn(Float32, 5, 4)
Whh = randn(Float32, 5, 5)
b = randn(Float32, 5)

function rnn_cell(h, x)
  h = tanh.(Wxh * x .+ Whh * h .+ b)
  return h, h
end

x = rand(Float32, 4) # dummy data
h = rand(Float32, 5) # initial hidden state

h, y = rnn_cell(h, x)
```

There are various recurrent cells available in Flux, notably `RNNCell`, `LSTMCell` and `GRUCell`. The hand-written example above can be replaced with:

```julia
using Flux

rnn = Flux.RNNCell(4, 5)

x = rand(Float32, 4) # dummy data
h = rand(Float32, 5) # initial hidden state

h, y = rnn(h, x)
```

For the most part, we don't want to manage hidden states ourselves, but to treat our models as being stateful. Flux provides the `Recur` wrapper to do this.

```julia
x = rand(Float32, 4)
h = rand(Float32, 5)

m = Flux.Recur(rnn, h)

y = m(x)
```

The `Recur` wrapper stores the state between runs in the `m.state` field.
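
For example, the state can be inspected before and after a step (a minimal sketch, reusing the `m` and `x` defined above; the actual values depend on the random initialisation):

```julia
m.state      # the current hidden state of the wrapped cell
y = m(x)     # one step: returns the output and updates m.state
m.state      # now holds the hidden state produced by that step
```
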

If we use the `RNN(4, 5)` constructor – as opposed to `RNNCell` – you'll see that it's simply a wrapped cell.

```julia
julia> RNN(4, 5)
Recur(RNNCell(4, 5, tanh))
```

Equivalent to the `RNN` stateful constructor, `LSTM` and `GRU` are also available.
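
For instance, the stateful constructors can be used in the same way (a minimal sketch using the same feature sizes as above):

```julia
lstm = LSTM(4, 5)   # a Recur-wrapped LSTM cell
gru = GRU(4, 5)     # a Recur-wrapped GRU cell

lstm(rand(Float32, 4))  # length-5 output; the internal state is updated as well
```
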
Using these tools, we can now build the model shown in the above diagram with:
```julia
m = Chain(RNN(4, 5), Dense(5, 2))
```

In this example, each output has two components.

## Working with sequences
Using the previously defined `m` recurrent model, we can now apply it to a single step from our sequence:
```julia
julia> x = rand(Float32, 4);

julia> m(x)
2-element Vector{Float32}:
 -0.12852919
 0.009802654
```
The `m(x)` operation would be represented by `x1 -> A -> y1` in our diagram.
If we perform this operation a second time, it will be equivalent to `x2 -> A -> y2`, since the model `m` has stored the state resulting from the `x1` step.
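
To make the statefulness concrete, here is a small sketch (with hypothetical inputs `x1` and `x2`, reusing the `m` defined above):

```julia
x1 = rand(Float32, 4)
x2 = rand(Float32, 4)

out1 = m(x1)  # x1 -> A -> y1, starting from the current state
out2 = m(x2)  # x2 -> A -> y2, computed from the state left behind by the x1 step
```
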

Now, instead of computing a single step at a time, we can get the full `y1` to `y3` sequence in a single pass by iterating the model on a sequence of data.

To do so, we'll need to structure the input data as a `Vector` of observations at each time step. This `Vector` will therefore be of `length = seq_length` and each of its elements will represent the input features for a given step. In our example, this translates into a `Vector` of length 3, where each element is a `Matrix` of size `(features, batch_size)`, or just a `Vector` of length `features` if dealing with a single observation.
```julia
julia> x = [rand(Float32, 4) for i = 1:3];

julia> [m(x[i]) for i = 1:3]
3-element Vector{Vector{Float32}}:
 [-0.018976994, 0.61098206]
 [-0.8924057, -0.7512169]
 [-0.34613007, -0.54565114]
```
If for some reason one wants to exclude the first step of the RNN chain for the computation of the loss, that can be handled with:
```julia
function loss(x, y)
  m(x[1]) # ignores the output but updates the hidden states
  l = 0f0
  for i in 2:length(x)
    l += sum((m(x[i]) .- y[i-1]) .^ 2)
  end
  return l
end

y = [rand(Float32, 2) for i = 1:2]
loss(x, y)
```

In such a model, only the last two outputs are used to compute the loss, hence the target `y` being of length 2. This is a strategy that can be used to easily handle a `seq-to-one` kind of structure, compared to the `seq-to-seq` assumed so far.

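Taking this one step further, a pure `seq-to-one` loss could keep only the very last output (a sketch reusing the `m` and `x` from above; `loss_last` and `y_last` are hypothetical names):

```julia
function loss_last(x, y_last)
  out = m(x[1])
  for i in 2:length(x)
    out = m(x[i])  # run through the whole sequence, keeping only the final output
  end
  return sum((out .- y_last) .^ 2)
end

y_last = rand(Float32, 2)
loss_last(x, y_last)
```
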
Alternatively, if one wants to perform some warmup of the sequence, it could be performed once, followed by regular training where all the steps of the sequence are considered for the gradient update:
```julia
function loss(x, y)
  sum(sum((m(x[i]) .- y[i]) .^ 2) for i = 1:length(x))
end

seq_init = [rand(Float32, 4) for i = 1:1]
seq_1 = [rand(Float32, 4) for i = 1:3]
seq_2 = [rand(Float32, 4) for i = 1:3]

y1 = [rand(Float32, 2) for i = 1:3]
y2 = [rand(Float32, 2) for i = 1:3]

X = [seq_1, seq_2]
Y = [y1, y2]
data = zip(X, Y)

Flux.reset!(m)
[m(x) for x in seq_init]

ps = params(m)
opt = ADAM(1e-3)
Flux.train!(loss, ps, data, opt)
```

In the previous example, the model's state is first reset with `Flux.reset!`. Then a warmup is performed over a sequence of length 1 by feeding `seq_init` to the model, resulting in a warmup state. The model can then be trained for one epoch, where two batches are provided (`seq_1` and `seq_2`) and the outputs of all timesteps are considered for the loss.

In this scenario, it is important to note that a single continuous sequence is considered. Since the model state is not reset between the two batches, the state of the model flows through the batches, which only makes sense in the context where `seq_1` is the continuation of `seq_init` and so on.

Batch size would be 1 here as there's only a single sequence within each batch. If the model were to be trained on multiple independent sequences, these sequences could be added to the input data as a second dimension. For example, in a language model, each batch would contain multiple independent sentences. In such a scenario, if we set the batch size to 4, a single batch would be of the shape:
```julia
batch = [rand(Float32, 4, 4) for i = 1:3]
```

That would mean that we have 4 sentences (or samples), each with 4 features (let's say a very small embedding!) and each of length 3 (3 words per sentence). Computing `m(batch[1])` would still represent `x1 -> y1` in our diagram and return the first word output, but now for each of the 4 independent sentences (second dimension of the input matrix).

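As a quick check of shapes (a sketch, assuming the `m` and `batch` defined above), the output of a single step now carries one column per sentence:

```julia
size(m(batch[1]))  # (2, 4): 2 output features for each of the 4 sentences
```
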
In many situations, such as when dealing with a language model, each batch typically contains independent sentences, so we cannot handle the model as if each batch was the direct continuation of the previous one. To handle such a situation, we need to reset the state of the model between each batch, which can be conveniently performed within the loss function:
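
A minimal sketch of what such a loss could look like (assuming the same `m` and the same `x` / `y` sequence structure as above):

```julia
function loss(x, y)
  Flux.reset!(m)  # start every batch from the initial hidden state
  sum(sum((m(x[i]) .- y[i]) .^ 2) for i = 1:length(x))
end
```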
0 commit comments