To introduce Flux's recurrence functionalities, we will consider the following vanilla recurrent neural network structure:

In the above, we have a sequence of length 3, where `x1` to `x3` represent the input at each step (could be a timestamp or a word in a sentence), and `y1` to `y3` are their respective outputs.

An aspect to recognize is that in such a model, the recurrent cells `A` all refer to the same structure. What distinguishes it from a simple dense layer is that the cell `A` is fed, in addition to an input `x`, with information from the previous state of the model (hidden state denoted as `h1` & `h2` in the diagram).

In the most basic RNN case, cell A could be defined by the following:
```julia
output_size = 5
input_size = 2
Wxh = randn(Float32, output_size, input_size)
Whh = randn(Float32, output_size, output_size)
b = randn(Float32, output_size)

function rnn_cell(h, x)
    h = tanh.(Wxh * x .+ Whh * h .+ b)
    return h, h
end

x = rand(Float32, input_size)  # dummy input data
h = rand(Float32, output_size) # random initial hidden state

h, y = rnn_cell(h, x)
```
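Since the hidden state returned by one call is fed into the next call, the cell can be unrolled over the full 3-step sequence from the diagram. Below is a minimal sketch of such an unroll (our addition, not part of the original docs), using a hypothetical `run_cell` helper together with the definitions above:

```julia
function run_cell(h, xs)
    ys = Vector{Vector{Float32}}()
    for x in xs
        h, y = rnn_cell(h, x)  # the updated hidden state feeds the next step
        push!(ys, y)
    end
    return h, ys
end

xs = [rand(Float32, input_size) for _ in 1:3]      # x1, x2, x3
h, ys = run_cell(zeros(Float32, output_size), xs)  # ys holds y1, y2, y3
```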
Using the previously defined `m` recurrent model, we can now apply it to a single step from our sequence:
```julia
julia> x = rand(Float32, 2);

julia> m(x)
1-element Vector{Float32}:
 0.31759313
```
The `m(x)` operation would be represented by `x1 -> A -> y1` in our diagram.
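
Each further call to `m` advances the internal state by one step; for instance, a second call would correspond to `x2 -> A -> y2` in the diagram (a sketch on our part; the actual value depends on the random weights):

```julia
y2 = m(rand(Float32, 2))  # the hidden state from the first call carries over
```

Mapping `m` over a full sequence then produces all of the outputs: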
```julia
julia> x = [rand(Float32, 2) for i = 1:3];

julia> [m(xi) for xi in x]
3-element Vector{Vector{Float32}}:
 [-0.033448644]
 [0.5259023]
 [-0.11183384]
```
!!! warning "Use of map and broadcast"
```julia
x = [rand(Float32, 2, 4) for i = 1:3]
y = [rand(Float32, 1, 4) for i = 1:3]
```
That would mean that we have 4 sentences (or samples), each with 2 features (let's say a very small embedding!) and each with a length of 3 (3 words per sentence). Computing `m(x[1])` would still represent `x1 -> y1` in our diagram and return the first word's output, but now for each of the 4 independent sentences (second dimension of the input matrix). We do not need to use `Flux.reset!(m)` here; each sentence in the batch will output in its own "column", and the outputs of the different sentences won't mix.

To illustrate, we go through an example of batching with our implementation of `rnn_cell`. The implementation doesn't need to change; the batching comes for "free" from the way Julia does broadcasting and the rules of matrix multiplication.
```julia
output_size = 5
input_size = 2
Wxh = randn(Float32, output_size, input_size)
Whh = randn(Float32, output_size, output_size)
b = randn(Float32, output_size)

function rnn_cell(h, x)
    h = tanh.(Wxh * x .+ Whh * h .+ b)
    return h, h
end
```
Here, we use the last dimension of the input and the hidden state as the batch dimension. That is, `h[:, n]` would be the hidden state of the nth sentence in the batch.
```julia
batch_size = 4
x = rand(Float32, input_size, batch_size)  # dummy input data
h = rand(Float32, output_size, batch_size) # random initial hidden state

h, y = rnn_cell(h, x)
```
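
As a quick sanity check (our addition, not from the original docs) that the batch dimension really keeps the sentences independent, the first column of the batched result matches running the first sentence through the cell on its own:

```julia
h0 = rand(Float32, output_size, batch_size)
x0 = rand(Float32, input_size, batch_size)

_, y_batched = rnn_cell(h0, x0)              # all 4 sentences at once
_, y_single  = rnn_cell(h0[:, 1], x0[:, 1])  # sentence 1 alone

y_batched[:, 1] ≈ y_single  # true: the other columns don't affect sentence 1
```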
In many situations, such as when dealing with a language model, the sentences in each batch are independent (i.e. the last item of the first sentence of the first batch is independent of the first item of the first sentence of the second batch), so we cannot handle the model as if each batch was the direct continuation of the previous one. To handle such situations, we need to reset the state of the model between each batch, which can be conveniently performed within the loss function:
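
A minimal sketch of such a loss function (assuming, on our part, a mean-squared-error objective via `Flux.Losses.mse` and the recurrent model `m` from earlier):

```julia
using Flux
using Flux.Losses: mse

function loss(x, y)
    Flux.reset!(m)  # wipe the hidden state carried over from the previous batch
    sum(mse(m(xi), yi) for (xi, yi) in zip(x, y))
end
```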