@@ -146,13 +146,13 @@ testmode!(m::AlphaDropout, mode=true) =
146
146
LayerNorm(sz, λ=identity; affine=true, ϵ=1fe-5)
147
147
148
148
A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be
149
- used with recurrent hidden states.
150
- The argument `sz` should be an integer or a tuple of integers.
151
- In the forward pass, the layer normalises the mean and standard
149
+ used with recurrent hidden states.
150
+ The argument `sz` should be an integer or a tuple of integers.
151
+ In the forward pass, the layer normalises the mean and standard
152
152
deviation of the input, the applied the elementwise activation `λ`.
153
153
The input is normalised along the first `length(sz)` dimensions
154
154
for tuple `sz`, along the first dimension for integer `sz`.
155
- The input is expected to have first dimensions' size equal to `sz`.
155
+ The input is expected to have first dimensions' size equal to `sz`.
156
156
157
157
If `affine=true` also applies a learnable shift and rescaling
158
158
as in the [`Diagonal`](@ref) layer.
@@ -192,35 +192,48 @@ end
192
192
# Compute the statistics on the slices specified by reduce_dims.
193
193
# reduce_dims=[1,...,N-2,N] for BatchNorm
194
194
# reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm
195
- function _norm_layer_forward (l, x:: AbstractArray{T,N} ; reduce_dims, affine_shape) where {T, N}
195
+ function _norm_layer_forward (
196
+ l, x:: AbstractArray{T, N} ; reduce_dims, affine_shape,
197
+ ) where {T, N}
196
198
if ! _isactive (l) && l. track_stats # testmode with tracked stats
197
199
stats_shape = ntuple (i -> i == N- 1 ? size (x, N- 1 ) : 1 , N)
198
200
μ = reshape (l. μ, stats_shape)
199
201
σ² = reshape (l. σ², stats_shape)
200
- else # trainmode or testmode without tracked stats
202
+ else # trainmode or testmode without tracked stats
201
203
μ = mean (x; dims= reduce_dims)
202
204
σ² = mean ((x .- μ). ^ 2 ; dims= reduce_dims)
203
205
if l. track_stats
204
- # # update moving mean/std
205
- Zygote. ignore () do
206
- mtm = l. momentum
207
- m = prod (size (x, i) for i in reduce_dims) # needed for computing corrected var
208
- μnew = vec (N ∈ reduce_dims ? μ : mean (μ, dims= N))
209
- σ²new = vec (N ∈ reduce_dims ? σ² : mean (σ², dims= N))
210
- l. μ = (1 - mtm) .* l. μ .+ mtm .* μnew
211
- l. σ² = (1 - mtm) .* l. σ² .+ mtm .* (m / (m - one (eltype (l. σ²)))) .* σ²new
212
- end
206
+ _track_stats! (l, x, μ, σ², reduce_dims) # update moving mean/std
213
207
end
214
208
end
215
- if hasaffine (l)
216
- γ = reshape (l . γ, affine_shape )
217
- β = reshape (l . β, affine_shape )
218
- return l . λ .(γ .* (x .- μ) ./ sqrt .(σ² .+ l . ϵ) .+ β)
219
- else
220
- return l . λ .((x .- μ) ./ sqrt .(σ² .+ l . ϵ) )
221
- end
209
+
210
+ o = _norm_layer_forward (x, μ, σ², l . ϵ )
211
+ hasaffine (l) || return l . λ .(o )
212
+
213
+ γ = reshape (l . γ, affine_shape)
214
+ β = reshape (l . β, affine_shape )
215
+ return l . λ .(γ .* o .+ β)
222
216
end
223
217
218
+ @inline _norm_layer_forward (x, μ, σ², ϵ) = (x .- μ) ./ sqrt .(σ² .+ ϵ)
219
+
220
+ function _track_stats! (
221
+ bn, x:: AbstractArray{T, N} , μ, σ², reduce_dims,
222
+ ) where {T, N}
223
+ V = eltype (bn. σ²)
224
+ mtm = bn. momentum
225
+ res_mtm = one (V) - mtm
226
+ m = prod (size (x, i) for i in reduce_dims)
227
+
228
+ μnew = vec (N ∈ reduce_dims ? μ : mean (μ, dims= N))
229
+ σ²new = vec (N ∈ reduce_dims ? σ² : mean (σ², dims= N))
230
+
231
+ bn. μ = res_mtm .* bn. μ .+ mtm .* μnew
232
+ bn. σ² = res_mtm .* bn. σ² .+ mtm .* (m / (m - one (V))) .* σ²new
233
+ return nothing
234
+ end
235
+ Zygote. @nograd _track_stats!
236
+
224
237
"""
225
238
BatchNorm(channels::Integer, λ=identity;
226
239
initβ=zeros32, initγ=ones32,
@@ -234,15 +247,15 @@ Given an array with `N` dimensions, call the `N-1`th the channel dimension. For
234
247
a batch of feature vectors this is just the data dimension, for `WHCN` images
235
248
it's the usual channel dimension.
236
249
237
- `BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
250
+ `BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
238
251
input slice and normalises the input accordingly.
239
252
240
- If `affine=true`, it also applies a shift and a rescale to the input
253
+ If `affine=true`, it also applies a shift and a rescale to the input
241
254
through to learnable per-channel bias β and scale γ parameters.
242
255
243
- After normalisation, elementwise activation `λ` is applied.
256
+ After normalisation, elementwise activation `λ` is applied.
244
257
245
- If `track_stats=true`, accumulates mean and var statistics in training phase
258
+ If `track_stats=true`, accumulates mean and var statistics in training phase
246
259
that will be used to renormalize the input in test phase.
247
260
248
261
Use [`testmode!`](@ref) during inference.
@@ -272,7 +285,7 @@ mutable struct BatchNorm{F,V,N,W}
272
285
end
273
286
274
287
function BatchNorm (chs:: Int , λ= identity;
275
- initβ= zeros32, initγ= ones32,
288
+ initβ= zeros32, initγ= ones32,
276
289
affine= true , track_stats= true ,
277
290
ϵ= 1f-5 , momentum= 0.1f0 )
278
291
@@ -282,8 +295,8 @@ function BatchNorm(chs::Int, λ=identity;
282
295
σ² = track_stats ? ones32 (chs) : nothing
283
296
284
297
return BatchNorm (λ, β, γ,
285
- μ, σ², ϵ, momentum,
286
- affine, track_stats,
298
+ μ, σ², ϵ, momentum,
299
+ affine, track_stats,
287
300
nothing , chs)
288
301
end
289
302
@@ -318,19 +331,19 @@ end
318
331
[Instance Normalization](https://arxiv.org/abs/1607.08022) layer.
319
332
`channels` should be the size of the channel dimension in your data (see below).
320
333
321
- Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
334
+ Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
322
335
For `WHCN` images it's the usual channel dimension.
323
336
324
- `InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1`
337
+ `InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1`
325
338
input slice and normalises the input accordingly.
326
339
327
- If `affine=true`, it also applies a shift and a rescale to the input
340
+ If `affine=true`, it also applies a shift and a rescale to the input
328
341
through to learnable per-channel bias `β` and scale `γ` parameters.
329
342
330
- If `track_stats=true`, accumulates mean and var statistics in training phase
343
+ If `track_stats=true`, accumulates mean and var statistics in training phase
331
344
that will be used to renormalize the input in test phase.
332
345
333
- **Warning**: the defaults for `affine` and `track_stats` used to be `true`
346
+ **Warning**: the defaults for `affine` and `track_stats` used to be `true`
334
347
in previous Flux versions (< v0.12).
335
348
"""
336
349
mutable struct InstanceNorm{F,V,N,W}
@@ -358,7 +371,7 @@ function InstanceNorm(chs::Int, λ=identity;
358
371
σ² = track_stats ? ones32 (chs) : nothing
359
372
360
373
return InstanceNorm (λ, β, γ,
361
- μ, σ², ϵ, momentum,
374
+ μ, σ², ϵ, momentum,
362
375
affine, track_stats,
363
376
nothing , chs)
364
377
end
@@ -401,13 +414,13 @@ The number of channels must be an integer multiple of the number of groups.
401
414
402
415
`channels` should be the size of the channel dimension in your data (see below).
403
416
404
- Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
417
+ Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
405
418
For `WHCN` images it's the usual channel dimension.
406
419
407
- If `affine=true`, it also applies a shift and a rescale to the input
420
+ If `affine=true`, it also applies a shift and a rescale to the input
408
421
through to learnable per-channel bias `β` and scale `γ` parameters.
409
422
410
- If `track_stats=true`, accumulates mean and var statistics in training phase
423
+ If `track_stats=true`, accumulates mean and var statistics in training phase
411
424
that will be used to renormalize the input in test phase.
412
425
"""
413
426
mutable struct GroupNorm{F,V,N,W}
429
442
trainable (gn:: GroupNorm ) = hasaffine (gn) ? (gn. β, gn. γ) : ()
430
443
431
444
function GroupNorm (chs:: Int , G:: Int , λ= identity;
432
- initβ= zeros32, initγ= ones32,
445
+ initβ= zeros32, initγ= ones32,
433
446
affine= true , track_stats= false ,
434
447
ϵ= 1f-5 , momentum= 0.1f0 )
435
448
@@ -440,11 +453,11 @@ function GroupNorm(chs::Int, G::Int, λ=identity;
440
453
μ = track_stats ? zeros32 (G) : nothing
441
454
σ² = track_stats ? ones32 (G) : nothing
442
455
443
- return GroupNorm (G, λ,
456
+ return GroupNorm (G, λ,
444
457
β, γ,
445
- μ, σ²,
446
- ϵ, momentum,
447
- affine, track_stats,
458
+ μ, σ²,
459
+ ϵ, momentum,
460
+ affine, track_stats,
448
461
nothing , chs)
449
462
end
450
463
475
488
"""
476
489
hasaffine(l)
477
490
478
- Return `true` if a normalisation layer has trainable shift and
491
+ Return `true` if a normalisation layer has trainable shift and
479
492
scale parameters, `false` otherwise.
480
493
481
494
See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref).
0 commit comments