
Commit 85c17ee

Merge #1763
1763: Doc update (recurrence.md): fixed incorrect output dimensions, clarified batching. r=CarloLucibello a=NightMachinary

Co-authored-by: NightMachinary <rudiwillalwaysloveyou@gmail.com>
Co-authored-by: NightMachinary <36224762+NightMachinary@users.noreply.github.com>
2 parents 4e28377 + 267b115 commit 85c17ee


docs/src/models/recurrence.md

Lines changed: 45 additions & 14 deletions
@@ -8,22 +8,24 @@ To introduce Flux's recurrence functionalities, we will consider the following v
 In the above, we have a sequence of length 3, where `x1` to `x3` represent the input at each step (could be a timestamp or a word in a sentence), and `y1` to `y3` are their respective outputs.
 
-An aspect to recognize is that in such model, the recurrent cells `A` all refer to the same structure. What distinguishes it from a simple dense layer is that the cell `A` is fed, in addition to an input `x`, with information from the previous state of the model (hidden state denoted as `h1` & `h2` in the diagram).
+An aspect to recognize is that in such a model, the recurrent cells `A` all refer to the same structure. What distinguishes it from a simple dense layer is that the cell `A` is fed, in addition to an input `x`, with information from the previous state of the model (hidden state denoted as `h1` & `h2` in the diagram).
 
 In the most basic RNN case, cell A could be defined by the following:
 
 ```julia
-Wxh = randn(Float32, 5, 2)
-Whh = randn(Float32, 5, 5)
-b = randn(Float32, 5)
+output_size = 5
+input_size = 2
+Wxh = randn(Float32, output_size, input_size)
+Whh = randn(Float32, output_size, output_size)
+b = randn(Float32, output_size)
 
 function rnn_cell(h, x)
     h = tanh.(Wxh * x .+ Whh * h .+ b)
     return h, h
 end
 
-x = rand(Float32, 2) # dummy data
-h = rand(Float32, 5) # initial hidden state
+x = rand(Float32, input_size) # dummy input data
+h = rand(Float32, output_size) # random initial hidden state
 
 h, y = rnn_cell(h, x)
 ```
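
To make the recurrence concrete, here is a small sketch (reusing `rnn_cell`, `input_size`, and `output_size` from the hunk above) of unrolling the cell over the 3-step sequence in the diagram; the same weights serve every step, and only the hidden state is threaded through:

```julia
# Unroll the cell over a 3-step sequence (x1, x2, x3 in the diagram).
x1, x2, x3 = [rand(Float32, input_size) for _ in 1:3]
h0 = zeros(Float32, output_size)  # initial hidden state

h1, y1 = rnn_cell(h0, x1)  # x1 -> y1
h2, y2 = rnn_cell(h1, x2)  # x2 -> y2, fed h1 from the previous step
h3, y3 = rnn_cell(h2, x3)  # x3 -> y3, same Wxh, Whh, b throughout
```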
@@ -84,9 +86,8 @@ Using the previously defined `m` recurrent model, we can now apply it to a singl
 julia> x = rand(Float32, 2);
 
 julia> m(x)
-2-element Vector{Float32}:
- -0.12852919
-  0.009802654
+1-element Vector{Float32}:
+ 0.31759313
 ```
 
 The `m(x)` operation would be represented by `x1 -> A -> y1` in our diagram.
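
For readers following along, `m` is defined earlier on the page and is not part of this hunk; a hypothetical stand-in consistent with the corrected shapes (2 input features, a 1-element output) could be:

```julia
using Flux

# Hypothetical stand-in for the page's `m`; the exact layer sizes are
# an assumption, chosen to match the shapes in the REPL output above.
m = Chain(RNN(2, 5), Dense(5, 1))

x = rand(Float32, 2)
m(x)  # 1-element Vector{Float32}, as in the corrected output
```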
@@ -103,9 +104,9 @@ julia> x = [rand(Float32, 2) for i = 1:3];
 
 julia> [m(xi) for xi in x]
 3-element Vector{Vector{Float32}}:
-[-0.018976994, 0.61098206]
-[-0.8924057, -0.7512169]
-[-0.34613007, -0.54565114]
+[-0.033448644]
+[0.5259023]
+[-0.11183384]
 ```
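
Note that the recurrent `m` is stateful: each call advances the hidden state, which is what makes looping over the sequence meaningful. Before feeding a fresh, unrelated sequence, the state can be cleared with `Flux.reset!` (used later on this page); a brief sketch, assuming the stand-in `m` above:

```julia
# Each call to `m` advances its hidden state, so a new, unrelated
# sequence should start from a freshly reset state.
Flux.reset!(m)
y = [m(xi) for xi in x]  # y1, y2, y3 for the new sequence
```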
 
 !!! warning "Use of map and broadcast"
@@ -175,9 +176,39 @@ x = [rand(Float32, 2, 4) for i = 1:3]
 y = [rand(Float32, 1, 4) for i = 1:3]
 ```
 
-That would mean that we have 4 sentences (or samples), each with 2 features (let's say a very small embedding!) and each with a length of 3 (3 words per sentence). Computing `m(batch[1])`, would still represent `x1 -> y1` in our diagram and returns the first word output, but now for each of the 4 independent sentences (second dimension of the input matrix).
+That would mean that we have 4 sentences (or samples), each with 2 features (let's say a very small embedding!) and each with a length of 3 (3 words per sentence). Computing `m(batch[1])` would still represent `x1 -> y1` in our diagram and return the first word's output, but now for each of the 4 independent sentences (second dimension of the input matrix). We do not need to use `Flux.reset!(m)` here; each sentence in the batch will output in its own "column", and the outputs of the different sentences won't mix.
+
+To illustrate, we go through an example of batching with our implementation of `rnn_cell`. The implementation doesn't need to change; the batching comes for "free" from the way Julia does broadcasting and from the rules of matrix multiplication.
+
+```julia
+output_size = 5
+input_size = 2
+Wxh = randn(Float32, output_size, input_size)
+Whh = randn(Float32, output_size, output_size)
+b = randn(Float32, output_size)
+
+function rnn_cell(h, x)
+    h = tanh.(Wxh * x .+ Whh * h .+ b)
+    return h, h
+end
+```
+
+Here, we use the last dimension of the input and the hidden state as the batch dimension, i.e. `h[:, n]` would be the hidden state of the nth sentence in the batch.
+
+```julia
+batch_size = 4
+x = rand(Float32, input_size, batch_size)  # dummy input data
+h = rand(Float32, output_size, batch_size) # random initial hidden state
+
+h, y = rnn_cell(h, x)
+```
+
+```julia
+julia> size(h) == size(y) == (output_size, batch_size)
+true
+```
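
As a quick check of the independence claim (reusing the batched definitions just above), running one sentence on its own reproduces its column from the batched call:

```julia
# Column n of the batched result matches running sentence n alone,
# so the sentences in a batch do not interact.
h0 = rand(Float32, output_size, batch_size)
h_batch, _ = rnn_cell(h0, x)
h_one, _ = rnn_cell(h0[:, 1], x[:, 1])
h_batch[:, 1] ≈ h_one  # true, up to floating-point rounding
```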
 
-In many situations, such as when dealing with a language model, each batch typically contains independent sentences, so we cannot handle the model as if each batch was the direct continuation of the previous one. To handle such situation, we need to reset the state of the model between each batch, which can be conveniently performed within the loss function:
+In many situations, such as when dealing with a language model, the sentences in each batch are independent (i.e. the last item of the first sentence of the first batch is independent of the first item of the first sentence of the second batch), so we cannot handle the model as if each batch were the direct continuation of the previous one. To handle such situations, we need to reset the state of the model between batches, which can conveniently be done within the loss function:
 
 ```julia
 function loss(x, y)
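
The hunk is truncated at the signature; purely as an illustration (not the diff's actual continuation), a reset-inside-the-loss pattern along these lines, assuming the model `m` and `Flux.Losses.mse`, might look like:

```julia
using Flux

# Hypothetical sketch; the real body is truncated in this diff.
function loss(x, y)
    Flux.reset!(m)  # start each batch from a fresh hidden state
    return sum(Flux.Losses.mse(m(xi), yi) for (xi, yi) in zip(x, y))
end
```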
