@@ -142,33 +142,33 @@ julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
142
142
(5, 10)
143
143
```
144
144
145
- The following is a demonstration of when failing to call `reset!` between batch size changes can cause erroneous outputs.
145
+ !!! warning Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the following example:
146
146
147
- ```julia
148
- julia> r = RNN(3, 5)
149
- Recur(
150
- RNNCell(3, 5, tanh), # 50 parameters
151
- ) # Total: 4 trainable arrays, 50 parameters,
152
- # plus 1 non-trainable, 5 parameters, summarysize 432 bytes.
147
+ ```julia
148
+ julia> r = RNN(3, 5)
149
+ Recur(
150
+ RNNCell(3, 5, tanh), # 50 parameters
151
+ ) # Total: 4 trainable arrays, 50 parameters,
152
+ # plus 1 non-trainable, 5 parameters, summarysize 432 bytes.
153
153
154
- julia> r.state |> size
155
- (5, 1)
154
+ julia> r.state |> size
155
+ (5, 1)
156
156
157
- julia> r(rand(Float32, 3)) |> size
158
- (5,)
157
+ julia> r(rand(Float32, 3)) |> size
158
+ (5,)
159
159
160
- julia> r.state |> size
161
- (5, 1)
160
+ julia> r.state |> size
161
+ (5, 1)
162
162
163
- julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
164
- (5, 10)
163
+ julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
164
+ (5, 10)
165
165
166
- julia> r.state |> size # state shape has changed
167
- (5, 10)
166
+ julia> r.state |> size # state shape has changed
167
+ (5, 10)
168
168
169
- julia> r(rand(Float32, 3)) |> size # erroneously outputs a length 5*10 = 50 vector.
170
- (50,)
171
- ```
169
+ julia> r(rand(Float32, 3)) |> size # erroneously outputs a length 5*10 = 50 vector.
170
+ (50,)
171
+ ```
172
172
"""
173
173
RNN (a... ; ka... ) = Recur (RNNCell (a... ; ka... ))
174
174
Recur (m:: RNNCell ) = Recur (m, m. state0)
0 commit comments