docs/src/models/recurrence.md

An aspect to recognize is that in such a model, the recurrent cells `A` all refer to the same structure.

In the most basic RNN case, cell A could be defined by the following:

```julia
Wxh = randn(Float32, 5, 2)
Whh = randn(Float32, 5, 5)
b = randn(Float32, 5)

function rnn_cell(h, x)
    h = tanh.(Wxh * x .+ Whh * h .+ b)
    return h, h
end

x = rand(Float32, 2) # dummy data
h = rand(Float32, 5) # initial hidden state

h, y = rnn_cell(h, x)
```
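
To make the recurrence explicit, here is a small sketch (reusing `rnn_cell` and the weights defined above; the dummy sequence and zero initial state are placeholder choices) that unrolls the cell over a few steps, threading the hidden state from one call to the next:

```julia
xs = [rand(Float32, 2) for _ in 1:3]  # a dummy 3-step sequence
h = zeros(Float32, 5)                 # start from a zero hidden state
ys = Vector{Vector{Float32}}()
for x in xs
    h, y = rnn_cell(h, x)             # the same weights are reused at every step
    push!(ys, y)
end
```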

There are various recurrent cells available in Flux, notably `RNNCell`, `LSTMCell` and `GRUCell`, which can replace the hand-written cell above:

```julia
using Flux

rnn = Flux.RNNCell(2, 5)

x = rand(Float32, 2) # dummy data
h = rand(Float32, 5) # initial hidden state

h, y = rnn(h, x)
```

For the most part, we don't want to manage hidden states ourselves, but to treat our models as being stateful. Flux provides the `Recur` wrapper to do this.

```julia
x = rand(Float32, 2)
h = rand(Float32, 5)

m = Flux.Recur(rnn, h)

y = m(x)
```

The `Recur` wrapper stores the state between runs in the `m.state` field.
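
As a quick illustration of that field (a sketch, assuming the `m` and `x` from above; `Flux.reset!` restores the initial state):

```julia
h0 = copy(m.state)  # snapshot the current hidden state
y = m(x)            # one step: updates m.state in place
m.state == h0       # false (almost surely): the state has advanced
Flux.reset!(m)      # restore the hidden state to its initial value
```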

If we use the `RNN(2, 5)` constructor – as opposed to `RNNCell` – you'll see that it's simply a wrapped cell.

```julia
julia> RNN(2, 5)
Recur(RNNCell(2, 5, tanh))
```

Equivalent to the `RNN` stateful constructor, `LSTM` and `GRU` are also available.
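
For instance, a minimal sketch with the other constructors (dimensions chosen to match the examples above):

```julia
lstm = LSTM(2, 5)           # a Recur-wrapped LSTMCell
gru = GRU(2, 5)             # a Recur-wrapped GRUCell

y = lstm(rand(Float32, 2))  # stateful call, just like m above
```
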
Using these tools, we can now build the model shown in the above diagram with:

```julia
m = Chain(RNN(2, 5), Dense(5, 2))
```

In this example, each output has two components.

Using the previously defined `m` recurrent model, we can now apply it to a single step from our sequence:

```julia
julia> x = rand(Float32, 2);

julia> m(x)
2-element Vector{Float32}:
 ⋮
```

Now, rather than computing a single step at a time, we can obtain the full output sequence in a single pass by iterating the model on a sequence of data.

To do so, we'll need to structure the input data as a `Vector` of observations at each time step. This `Vector` will therefore be of `length = seq_length` and each of its elements will represent the input features for a given step. In our example, this translates into a `Vector` of length 3, where each element is a `Matrix` of size `(features, batch_size)`, or just a `Vector` of length `features` if dealing with a single observation.

```julia
julia> x = [rand(Float32, 2) for i = 1:3];

julia> [m(xi) for xi in x]
3-element Vector{Vector{Float32}}:
 [-0.018976994, 0.61098206]
 [-0.8924057, -0.7512169]
 [-0.34613007, -0.54565114]
```

!!! warning "Use of map and broadcast"
    Mapping and broadcasting operations with stateful layers such as the one we are considering are discouraged,
    since the Julia language doesn't guarantee a specific execution order.
    Therefore, avoid
    ```julia
    y = m.(x)       # broadcasting over a stateful model
    y = map(m, x)   # mapping has the same problem
    ```
    and use an explicit loop or comprehension instead, as in the example above.

If for some reason one wants to exclude the first step of the RNN chain from the computation of the loss, that can be handled with:

```julia
using Flux.Losses: mse

function loss(x, y)
  m(x[1]) # ignores the output but updates the hidden states
  l = sum(mse(m(xi), yi) for (xi, yi) in zip(x[2:end], y))
  return l
end
```
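
As a sanity check, this loss can be evaluated on dummy data shaped like the sequences above (a sketch; note that `y` needs one element fewer than `x`, since the first step is skipped):

```julia
x = [rand(Float32, 2) for _ in 1:3]  # 3 input steps
y = [rand(Float32, 2) for _ in 1:2]  # targets for steps 2 and 3 only

Flux.reset!(m)  # start from a fresh hidden state
loss(x, y)
```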

Alternatively, if one wants to perform some warmup of the sequence, it could be performed once, followed by a regular training phase where all the steps of the sequence are considered for the gradient update:

```julia
function loss(x, y)
  sum(mse(m(xi), yi) for (xi, yi) in zip(x, y))
end

seq_init = [rand(Float32, 2)]
seq_1 = [rand(Float32, 2) for i = 1:3]
seq_2 = [rand(Float32, 2) for i = 1:3]

y1 = [rand(Float32, 2) for i = 1:3]
y2 = [rand(Float32, 2) for i = 1:3]
```
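
A sketch of how these pieces could fit together: run the model over `seq_init` purely for its side effect on the hidden state, then compute the loss on the following sequences, which continue from that state:

```julia
Flux.reset!(m)            # start from a fresh state
[m(x) for x in seq_init]  # warmup: outputs are discarded
l1 = loss(seq_1, y1)      # seq_1 continues from the warmed-up state
l2 = loss(seq_2, y2)      # seq_2 continues from seq_1's final state
```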

In this scenario, it is important to note that a single continuous sequence is considered.

Batch size would be 1 here as there's only a single sequence within each batch. If the model were to be trained on multiple independent sequences, these sequences could be added to the input data as a second dimension. For example, in a language model, each batch would contain multiple independent sentences. In such a scenario, if we set the batch size to 4, a single batch would be of the shape:

```julia
batch = [rand(Float32, 2, 4) for i = 1:3]
```

That would mean that we have 4 sentences (or samples), each with 2 features (let's say a very small embedding!) and each of length 3 (3 words per sentence). Computing `m(batch[1])` would still represent `x1 -> y1` in our diagram, but it now returns the first word's output for each of the 4 independent sentences (the second dimension of the input matrix).
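
Concretely, a quick shape check (a sketch using the `m` and `batch` defined above):

```julia
julia> size(m(batch[1]))  # (output features, batch size)
(2, 4)
```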

In many situations, such as when dealing with a language model, each batch typically contains independent sentences, so we cannot handle the model as if each batch was the direct continuation of the previous one. To handle such a situation, we need to reset the state of the model between each batch, which can be conveniently done within the loss function:
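
A minimal sketch of such a loss, assuming the same `m` and `mse` as above:

```julia
function loss(x, y)
  Flux.reset!(m)  # forget state carried over from the previous batch
  sum(mse(m(xi), yi) for (xi, yi) in zip(x, y))
end
```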
0 commit comments