@@ -122,13 +122,13 @@ testmode!(m::AlphaDropout, mode=true) =
LayerNorm(sz, λ=identity; affine=true, ϵ=1f-5)

A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be
- used with recurrent hidden states.
- The argument `sz` should be an integer or a tuple of integers.
- In the forward pass, the layer normalises the mean and standard
+ used with recurrent hidden states.
+ The argument `sz` should be an integer or a tuple of integers.
+ In the forward pass, the layer normalises the mean and standard
deviation of the input, then applies the elementwise activation `λ`.
The input is normalised along the first `length(sz)` dimensions
for tuple `sz`, along the first dimension for integer `sz`.
- The input is expected to have first dimensions' size equal to `sz`.
+ The input is expected to have first dimensions' size equal to `sz`.

If `affine=true` also applies a learnable shift and rescaling
as in the [`Diagonal`](@ref) layer.
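A brief usage sketch of the docstring above (hedged: with the default `affine=true` the initial scale is one and the shift zero, so the first forward pass only normalises; exact values are illustrative):

```julia
using Flux, Statistics

ln = LayerNorm(5)          # normalise over the first dimension, which must have size 5
x  = rand(Float32, 5, 16)  # 5 features, batch of 16
y  = ln(x)

mean(y; dims=1)  # ≈ 0 for every column
std(y; dims=1)   # ≈ 1 for every column (up to ϵ and the corrected/uncorrected std convention)
```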
@@ -164,38 +164,70 @@ function Base.show(io::IO, l::LayerNorm)
print(io, ")")
end

+ _maybe_promote_type(::Type{T1}, ::Type{T2}) where {T1, T2} = promote_type(T1, T2)
+ _maybe_promote_type(::Type{Nothing}, ::Type{T2}) where T2 = T2
+ _maybe_promote_type(::Type{T1}, ::Type{Nothing}) where T1 = T1
+
+ _maybe_eltype(::Type{T}) where T <: AbstractArray = eltype(T)
+ _maybe_eltype(::Type{Nothing}) = Nothing
+
+ abstract type Normalization{F, V, N, W} end
+
+ function _promote_to_output(
+   ::Normalization{F, V, N, W}, x::AbstractArray{T},
+ ) where {F, V, N, W, T}
+   Vel = _maybe_eltype(V)
+   Wel = _maybe_eltype(W)
+   _maybe_promote_type(_maybe_promote_type(
+     _maybe_promote_type(T, Vel), N), Wel)
+ end
+
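A hypothetical REPL check of the promotion helpers above (they are internal, so the `Flux.`-qualified names are an assumption about where this patch lands):

```julia
julia> Flux._maybe_promote_type(Float32, Float64)
Float64

julia> Flux._maybe_promote_type(Nothing, Float32)  # Nothing stands in for absent params (e.g. affine=false)
Float32

julia> Flux._maybe_eltype(Vector{Float64})
Float64
```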
# For InstanceNorm, GroupNorm, and BatchNorm.
# Compute the statistics on the slices specified by reduce_dims.
# reduce_dims=[1,...,N-2,N] for BatchNorm
# reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm
- function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape) where {T, N}
+ _norm_layer_forward(l, x; reduce_dims, affine_shape) =
+   _norm_layer_forward(l, x, _promote_to_output(l, x); reduce_dims, affine_shape)
+
+ function _norm_layer_forward(
+   l, x::Array{T, N}, ::Type{O}; reduce_dims, affine_shape,
+ ) where {T, N, O}
  if !_isactive(l) && l.track_stats # testmode with tracked stats
    stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
    μ = reshape(l.μ, stats_shape)
    σ² = reshape(l.σ², stats_shape)
- else # trainmode or testmode without tracked stats
+ else # trainmode or testmode without tracked stats
    μ = mean(x; dims=reduce_dims)
    σ² = mean((x .- μ).^2; dims=reduce_dims)
    if l.track_stats
-     ## update moving mean/std
-     Zygote.ignore() do
-       mtm = l.momentum
-       m = prod(size(x, i) for i in reduce_dims)  # needed for computing corrected var
-       μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
-       σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
-       l.μ = (1 - mtm) .* l.μ .+ mtm .* μnew
-       l.σ² = (1 - mtm) .* l.σ² .+ mtm .* (m / (m - one(eltype(l.σ²)))) .* σ²new
-     end
+     _track_stats!(l, x, μ, σ², reduce_dims) # update moving mean/std
    end
  end
- if hasaffine(l)
-   γ = reshape(l.γ, affine_shape)
-   β = reshape(l.β, affine_shape)
-   return l.λ.(γ .* (x .- μ) ./ sqrt.(σ² .+ l.ϵ) .+ β)
- else
-   return l.λ.((x .- μ) ./ sqrt.(σ² .+ l.ϵ))
- end
+
+ o::Array{O, N} = ((x .- μ) ./ sqrt.(σ² .+ l.ϵ))
+ hasaffine(l) || return l.λ.(o)
+
+ γ = reshape(l.γ, affine_shape)
+ β = reshape(l.β, affine_shape)
+ return l.λ.(γ .* o .+ β)
+ end
+
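To make the `reduce_dims`/`affine_shape` convention concrete, here is a small self-contained sketch for a hypothetical `WHCN` batch (the sizes are invented; only the dimension bookkeeping matters):

```julia
using Statistics

x = rand(Float32, 28, 28, 3, 16)        # W×H×C×N = 28×28 feature maps, 3 channels, batch of 16

# BatchNorm reduces over W, H and the batch dimension, keeping one statistic per channel:
#   reduce_dims  = [1, 2, 4]
#   affine_shape = (1, 1, 3, 1)          # γ and β broadcast along the channel axis
μ  = mean(x; dims=[1, 2, 4])             # size (1, 1, 3, 1)
σ² = mean((x .- μ).^2; dims=[1, 2, 4])   # biased variance, matching the code above

# InstanceNorm and GroupNorm reduce over W and H only (reduce_dims = [1, 2]),
# giving per-channel, per-sample statistics of size (1, 1, 3, 16).
```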
+ function _track_stats!(
+   bn, x::AbstractArray{T, N}, μ, σ², reduce_dims,
+ ) where {T, N}
+   V = eltype(bn.σ²)
+   mtm = bn.momentum
+   res_mtm = one(V) - mtm
+   m = prod(size(x, i) for i in reduce_dims)
+
+   μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
+   σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
+
+   bn.μ = res_mtm .* bn.μ .+ mtm .* μnew
+   bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
+   nothing
end
+ Zygote.@nograd _track_stats!
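The running statistics above form an exponential moving average with a Bessel-corrected batch variance; a hedged numeric sketch with invented values:

```julia
# Hypothetical single-channel update, momentum mtm = 0.1f0:
μ_old, σ²_old     = 0.0f0, 1.0f0   # current running statistics (bn.μ, bn.σ²)
μ_batch, σ²_batch = 2.0f0, 4.0f0   # batch statistics from the forward pass
m = 8                              # number of elements each statistic was averaged over

μ_new  = (1 - 0.1f0) * μ_old  + 0.1f0 * μ_batch                   # 0.2
σ²_new = (1 - 0.1f0) * σ²_old + 0.1f0 * (m / (m - 1)) * σ²_batch  # ≈ 1.357
```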
"""
BatchNorm(channels::Integer, λ=identity;
@@ -210,15 +242,15 @@ Given an array with `N` dimensions, call the `N-1`th the channel dimension. For
a batch of feature vectors this is just the data dimension, for `WHCN` images
it's the usual channel dimension.

- `BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
+ `BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
input slice and normalises the input accordingly.

- If `affine=true`, it also applies a shift and a rescale to the input
+ If `affine=true`, it also applies a shift and a rescale to the input
through learnable per-channel bias `β` and scale `γ` parameters.

- After normalisation, elementwise activation `λ` is applied.
+ After normalisation, elementwise activation `λ` is applied.

- If `track_stats=true`, accumulates mean and var statistics in training phase
+ If `track_stats=true`, accumulates mean and var statistics in training phase
that will be used to renormalize the input in test phase.

Use [`testmode!`](@ref) during inference.
@@ -233,7 +265,7 @@ m = Chain(
softmax)
```
"""
- mutable struct BatchNorm{F,V,N,W}
+ mutable struct BatchNorm{F,V,N,W} <: Normalization{F, V, N, W}
  λ::F   # activation function
  β::V   # bias
  γ::V   # scale
@@ -248,7 +280,7 @@ mutable struct BatchNorm{F,V,N,W}
end

function BatchNorm(chs::Int, λ=identity;
-   initβ=zeros32, initγ=ones32,
+   initβ=zeros32, initγ=ones32,
    affine=true, track_stats=true,
    ϵ=1f-5, momentum=0.1f0)
@@ -258,8 +290,8 @@ function BatchNorm(chs::Int, λ=identity;
  σ² = track_stats ? ones32(chs) : nothing

  return BatchNorm(λ, β, γ,
-   μ, σ², ϵ, momentum,
-   affine, track_stats,
+   μ, σ², ϵ, momentum,
+   affine, track_stats,
    nothing, chs)
end
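A hedged usage sketch of the layer built by this constructor (default keywords as shown above; `trainmode!`/`testmode!` are Flux's standard mode switches, and the comments describe expected behaviour rather than checked output):

```julia
using Flux

bn = BatchNorm(3)                 # affine=true, track_stats=true by default
x  = randn(Float32, 4, 4, 3, 8)   # WHCN input with 3 channels

trainmode!(bn)
y = bn(x)       # normalised with batch statistics; running μ/σ² are updated via _track_stats!

testmode!(bn)
z = bn(x)       # normalised with the accumulated running statistics instead
size(bn.μ)      # (3,): one running mean per channel
```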
@@ -294,22 +326,22 @@ end
[Instance Normalization](https://arxiv.org/abs/1607.08022) layer.
`channels` should be the size of the channel dimension in your data (see below).

- Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
+ Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
For `WHCN` images it's the usual channel dimension.

- `InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1`
+ `InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1`
input slice and normalises the input accordingly.

- If `affine=true`, it also applies a shift and a rescale to the input
+ If `affine=true`, it also applies a shift and a rescale to the input
through learnable per-channel bias `β` and scale `γ` parameters.

- If `track_stats=true`, accumulates mean and var statistics in training phase
+ If `track_stats=true`, accumulates mean and var statistics in training phase
that will be used to renormalize the input in test phase.

- **Warning**: the defaults for `affine` and `track_stats` used to be `true`
+ **Warning**: the defaults for `affine` and `track_stats` used to be `true`
in previous Flux versions (< v0.12).
"""
- mutable struct InstanceNorm{F,V,N,W}
+ mutable struct InstanceNorm{F,V,N,W} <: Normalization{F, V, N, W}
  λ::F   # activation function
  β::V   # bias
  γ::V   # scale
@@ -334,7 +366,7 @@ function InstanceNorm(chs::Int, λ=identity;
  σ² = track_stats ? ones32(chs) : nothing

  return InstanceNorm(λ, β, γ,
-   μ, σ², ϵ, momentum,
+   μ, σ², ϵ, momentum,
    affine, track_stats,
    nothing, chs)
end
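A hedged sketch of the per-sample behaviour (defaults as in the constructor above, i.e. `affine=false` and `track_stats=false` since v0.12; values illustrative):

```julia
using Flux, Statistics

in3 = InstanceNorm(3)
x   = randn(Float32, 8, 8, 3, 4)   # WHCN: 8×8 feature maps, 3 channels, batch of 4
y   = in3(x)

# Each (channel, sample) slice is normalised on its own:
isapprox(mean(y[:, :, 1, 1]), 0f0; atol=1f-3)
isapprox(mean(y[:, :, 2, 3]), 0f0; atol=1f-3)
```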
@@ -377,16 +409,16 @@ The number of channels must be an integer multiple of the number of groups.
`channels` should be the size of the channel dimension in your data (see below).

- Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
+ Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
For `WHCN` images it's the usual channel dimension.

- If `affine=true`, it also applies a shift and a rescale to the input
+ If `affine=true`, it also applies a shift and a rescale to the input
through learnable per-channel bias `β` and scale `γ` parameters.

- If `track_stats=true`, accumulates mean and var statistics in training phase
+ If `track_stats=true`, accumulates mean and var statistics in training phase
that will be used to renormalize the input in test phase.
"""
- mutable struct GroupNorm{F,V,N,W}
+ mutable struct GroupNorm{F,V,N,W} <: Normalization{F, V, N, W}
  G::Int   # number of groups
  λ::F     # activation function
  β::V     # bias
trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : ()

function GroupNorm(chs::Int, G::Int, λ=identity;
-   initβ=zeros32, initγ=ones32,
+   initβ=zeros32, initγ=ones32,
    affine=true, track_stats=false,
    ϵ=1f-5, momentum=0.1f0)
@@ -416,11 +448,11 @@ function GroupNorm(chs::Int, G::Int, λ=identity;
  μ = track_stats ? zeros32(G) : nothing
  σ² = track_stats ? ones32(G) : nothing

- return GroupNorm(G, λ,
+ return GroupNorm(G, λ,
    β, γ,
-   μ, σ²,
-   ϵ, momentum,
-   affine, track_stats,
+   μ, σ²,
+   ϵ, momentum,
+   affine, track_stats,
    nothing, chs)
end
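A hedged sketch of the grouping requirement enforced by this constructor (6 channels split into 3 groups; values illustrative):

```julia
using Flux

gn = GroupNorm(6, 3)              # 6 channels, 3 groups of 2; chs must be divisible by G
x  = randn(Float32, 8, 8, 6, 4)   # WHCN input whose channel dim matches chs = 6
y  = gn(x)                        # statistics are computed per (group, sample)
size(y) == size(x)                # true: normalisation preserves the shape
```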
"""
hasaffine(l)

- Return `true` if a normalisation layer has trainable shift and
+ Return `true` if a normalisation layer has trainable shift and
scale parameters, `false` otherwise.

See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref).
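A hedged sketch of what `hasaffine` reports with the default keyword values shown in the constructors above (the `Flux.`-qualification is an assumption; adjust if the helper is exported):

```julia
using Flux

Flux.hasaffine(BatchNorm(3))      # true:  affine=true by default
Flux.hasaffine(InstanceNorm(3))   # false: affine=false by default since v0.12
Flux.hasaffine(LayerNorm(3))      # true:  LayerNorm defaults to affine=true
Flux.trainable(InstanceNorm(3))   # ():    no trainable shift/scale without affine
```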