@@ -146,13 +146,13 @@ testmode!(m::AlphaDropout, mode=true) =
     LayerNorm(sz, λ=identity; affine=true, ϵ=1f-5)

 A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be
-used with recurrent hidden states.
-The argument `sz` should be an integer or a tuple of integers.
-In the forward pass, the layer normalises the mean and standard
+used with recurrent hidden states.
+The argument `sz` should be an integer or a tuple of integers.
+In the forward pass, the layer normalises the mean and standard
 deviation of the input, then applies the elementwise activation `λ`.
 The input is normalised along the first `length(sz)` dimensions
 for tuple `sz`, along the first dimension for integer `sz`.
-The input is expected to have first dimensions' size equal to `sz`.
+The input is expected to have first dimensions' size equal to `sz`.

 If `affine=true` also applies a learnable shift and rescaling
 as in the [`Diagonal`](@ref) layer.
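For orientation (not part of this patch): a minimal usage sketch of `LayerNorm` under the defaults documented above.

```julia
using Flux

ln = LayerNorm(5)         # normalise over a first dimension of size 5
x = rand(Float32, 5, 10)  # 10 samples of 5 features
y = ln(x)                 # with fresh γ = 1, β = 0, each column has ≈ zero mean, unit std
size(y) == size(x)        # true
```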
@@ -188,39 +188,78 @@ function Base.show(io::IO, l::LayerNorm)
   print(io, ")")
 end

+_maybe_promote_type(::Type{T1}, ::Type{T2}) where {T1, T2} = promote_type(T1, T2)
+_maybe_promote_type(::Type{Nothing}, ::Type{T2}) where T2 = T2
+_maybe_promote_type(::Type{T1}, ::Type{Nothing}) where T1 = T1
+
+_maybe_eltype(::Type{T}) where T <: AbstractArray = eltype(T)
+_maybe_eltype(::Type{Nothing}) = Nothing
+
+abstract type Normalization{F, V, N, W} end
+
+function _promote_to_output(
+  ::Normalization{F, V, N, W}, x::AbstractArray{T},
+) where {F, V, N, W, T}
+  Vel = _maybe_eltype(V)
+  Wel = _maybe_eltype(W)
+  _maybe_promote_type(_maybe_promote_type(
+    _maybe_promote_type(T, Vel), N), Wel)
+end
+
+function _basetype(::Type{T}) where T
+  if T <: Array
+    return Array
+  elseif T <: CuArray
+    return CuArray
+  end
+  throw("Unsupported type $T")
+end
+
 # For InstanceNorm, GroupNorm, and BatchNorm.
 # Compute the statistics on the slices specified by reduce_dims.
 # reduce_dims=[1,...,N-2,N] for BatchNorm
 # reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm
-function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape) where {T, N}
+function _norm_layer_forward(
+  l, x::AbstractArray{T, N}; reduce_dims, affine_shape,
+) where {T, N}
   if !_isactive(l) && l.track_stats # testmode with tracked stats
     stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
     μ = reshape(l.μ, stats_shape)
     σ² = reshape(l.σ², stats_shape)
-  else  # trainmode or testmode without tracked stats
+  else # trainmode or testmode without tracked stats
     μ = mean(x; dims=reduce_dims)
     σ² = mean((x .- μ).^2; dims=reduce_dims)
     if l.track_stats
-      ## update moving mean/std
-      Zygote.ignore() do
-        mtm = l.momentum
-        m = prod(size(x, i) for i in reduce_dims)  # needed for computing corrected var
-        μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
-        σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
-        l.μ = (1 - mtm) .* l.μ .+ mtm .* μnew
-        l.σ² = (1 - mtm) .* l.σ² .+ mtm .* (m / (m - one(eltype(l.σ²)))) .* σ²new
-      end
+      _track_stats!(l, x, μ, σ², reduce_dims) # update moving mean/std
     end
   end
-  if hasaffine(l)
-    γ = reshape(l.γ, affine_shape)
-    β = reshape(l.β, affine_shape)
-    return l.λ.(γ .* (x .- μ) ./ sqrt.(σ² .+ l.ϵ) .+ β)
-  else
-    return l.λ.((x .- μ) ./ sqrt.(σ² .+ l.ϵ))
-  end
+
+  O = _promote_to_output(l, x)
+  o::_basetype(typeof(x)){O, N} = ((x .- μ) ./ sqrt.(σ² .+ l.ϵ))
+  hasaffine(l) || return l.λ.(o)
+
+  γ = reshape(l.γ, affine_shape)
+  β = reshape(l.β, affine_shape)
+  return l.λ.(γ .* o .+ β)
 end

+function _track_stats!(
+  bn, x::AbstractArray{T, N}, μ, σ², reduce_dims,
+) where {T, N}
+  V = eltype(bn.σ²)
+  mtm = bn.momentum
+  res_mtm = one(V) - mtm
+  m = prod(size(x, i) for i in reduce_dims)
+
+  μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
+  σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
+
+  bn.μ = res_mtm .* bn.μ .+ mtm .* μnew
+  bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
+  nothing
+end
+Zygote.@nograd _track_stats!
+
 """
     BatchNorm(channels::Integer, λ=identity;
               initβ=zeros32, initγ=ones32,
@@ -234,15 +273,15 @@ Given an array with `N` dimensions, call the `N-1`th the channel dimension. For
 a batch of feature vectors this is just the data dimension, for `WHCN` images
 it's the usual channel dimension.

-`BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
+`BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
 input slice and normalises the input accordingly.

-If `affine=true`, it also applies a shift and a rescale to the input
+If `affine=true`, it also applies a shift and a rescale to the input
 through learnable per-channel bias β and scale γ parameters.

-After normalisation, elementwise activation `λ` is applied.
+After normalisation, the elementwise activation `λ` is applied.

-If `track_stats=true`, accumulates mean and variance statistics in the training phase
+If `track_stats=true`, accumulates mean and variance statistics in the training phase
 that will be used to renormalize the input in the test phase.

 Use [`testmode!`](@ref) during inference.
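A short reminder (illustrative, not part of the patch) of what `testmode!` changes for layers with tracked statistics:

```julia
using Flux

bn = BatchNorm(4)        # track_stats=true by default
Flux.testmode!(bn)       # freeze: use stored bn.μ / bn.σ², not batch statistics
x = rand(Float32, 4, 8)
y = bn(x)                # deterministic; with fresh stats (μ=0, σ²=1), y ≈ x
```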
@@ -257,7 +296,7 @@ m = Chain(
   softmax)
 ```
 """
-mutable struct BatchNorm{F,V,N,W}
+mutable struct BatchNorm{F,V,N,W} <: Normalization{F, V, N, W}
   λ::F  # activation function
   β::V  # bias
   γ::V  # scale
@@ -272,7 +311,7 @@ mutable struct BatchNorm{F,V,N,W}
 end

 function BatchNorm(chs::Int, λ=identity;
-                   initβ=zeros32, initγ=ones32,
+                   initβ=zeros32, initγ=ones32,
                    affine=true, track_stats=true,
                    ϵ=1f-5, momentum=0.1f0)

@@ -282,8 +321,8 @@ function BatchNorm(chs::Int, λ=identity;
   σ² = track_stats ? ones32(chs) : nothing

   return BatchNorm(λ, β, γ,
-                   μ, σ², ϵ, momentum,
-                   affine, track_stats,
+                   μ, σ², ϵ, momentum,
+                   affine, track_stats,
                    nothing, chs)
 end

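To make the `_track_stats!` update concrete, a hedged check (assumes the defaults `momentum=0.1f0` and a freshly initialised `bn.μ == 0`; `m/(m-1)` above is the Bessel correction on the batch variance):

```julia
using Flux, Statistics

bn = BatchNorm(2)           # defaults: momentum = 0.1f0, bn.μ = [0, 0]
Flux.trainmode!(bn)
x = randn(Float32, 2, 100)
bn(x)                       # one forward pass; _track_stats! runs under @nograd

# moving average: bn.μ = (1 - mtm) .* bn.μ .+ mtm .* batch_mean
bn.μ ≈ 0.1f0 .* vec(mean(x, dims=2))  # true, since bn.μ started at zero
```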
@@ -318,22 +357,22 @@ end
 [Instance Normalization](https://arxiv.org/abs/1607.08022) layer.
 `channels` should be the size of the channel dimension in your data (see below).

-Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
+Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
 For `WHCN` images it's the usual channel dimension.

-`InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1`
+`InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1`
 input slice and normalises the input accordingly.

-If `affine=true`, it also applies a shift and a rescale to the input
+If `affine=true`, it also applies a shift and a rescale to the input
 through learnable per-channel bias `β` and scale `γ` parameters.

-If `track_stats=true`, accumulates mean and variance statistics in the training phase
+If `track_stats=true`, accumulates mean and variance statistics in the training phase
 that will be used to renormalize the input in the test phase.

-**Warning**: the defaults for `affine` and `track_stats` used to be `true`
+**Warning**: the defaults for `affine` and `track_stats` used to be `true`
 in previous Flux versions (< v0.12).
 """
-mutable struct InstanceNorm{F,V,N,W}
+mutable struct InstanceNorm{F,V,N,W} <: Normalization{F, V, N, W}
   λ::F  # activation function
   β::V  # bias
   γ::V  # scale
@@ -358,7 +397,7 @@ function InstanceNorm(chs::Int, λ=identity;
   σ² = track_stats ? ones32(chs) : nothing

   return InstanceNorm(λ, β, γ,
-                      μ, σ², ϵ, momentum,
+                      μ, σ², ϵ, momentum,
                       affine, track_stats,
                       nothing, chs)
 end
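Usage sketch for `InstanceNorm` (illustrative, not part of the patch):

```julia
using Flux

m = InstanceNorm(3)            # affine=false, track_stats=false by default (≥ v0.12)
x = rand(Float32, 8, 8, 3, 4)  # WHCN batch
y = m(x)                       # each 8×8 channel slice of each sample is normalised
size(y) == size(x)             # with its own mean and variance; true
```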
@@ -401,16 +440,16 @@ The number of channels must be an integer multiple of the number of groups.

 `channels` should be the size of the channel dimension in your data (see below).

-Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
+Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
 For `WHCN` images it's the usual channel dimension.

-If `affine=true`, it also applies a shift and a rescale to the input
+If `affine=true`, it also applies a shift and a rescale to the input
 through learnable per-channel bias `β` and scale `γ` parameters.

-If `track_stats=true`, accumulates mean and variance statistics in the training phase
+If `track_stats=true`, accumulates mean and variance statistics in the training phase
 that will be used to renormalize the input in the test phase.
 """
-mutable struct GroupNorm{F,V,N,W}
+mutable struct GroupNorm{F,V,N,W} <: Normalization{F, V, N, W}
   G::Int  # number of groups
   λ::F  # activation function
   β::V  # bias
@@ -429,7 +468,7 @@
 trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : ()

 function GroupNorm(chs::Int, G::Int, λ=identity;
-                   initβ=zeros32, initγ=ones32,
+                   initβ=zeros32, initγ=ones32,
                    affine=true, track_stats=false,
                    ϵ=1f-5, momentum=0.1f0)

@@ -440,11 +479,11 @@ function GroupNorm(chs::Int, G::Int, λ=identity;
   μ = track_stats ? zeros32(G) : nothing
   σ² = track_stats ? ones32(G) : nothing

-  return GroupNorm(G, λ,
+  return GroupNorm(G, λ,
                    β, γ,
-                   μ, σ²,
-                   ϵ, momentum,
-                   affine, track_stats,
+                   μ, σ²,
+                   ϵ, momentum,
+                   affine, track_stats,
                    nothing, chs)
 end

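Usage sketch for `GroupNorm` (illustrative): channels are split into `G` groups and statistics are shared within each group.

```julia
using Flux

gn = GroupNorm(8, 4)           # 8 channels split into 4 groups of 2
x = rand(Float32, 8, 8, 8, 2)  # WHCN with 8 channels
y = gn(x)                      # mean/variance computed per sample, per group
size(y) == size(x)             # true
```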
@@ -475,7 +514,7 @@
 """
     hasaffine(l)

-Return `true` if a normalisation layer has trainable shift and
+Return `true` if a normalisation layer has trainable shift and
 scale parameters, `false` otherwise.

 See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref).
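For example (illustrative, under the constructors' defaults shown in this file):

```julia
using Flux

Flux.hasaffine(BatchNorm(3))     # true  (affine=true by default)
Flux.hasaffine(InstanceNorm(3))  # false (affine=false by default since v0.12)
Flux.hasaffine(LayerNorm(5))     # true
```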
0 commit comments