
Commit 5244ade

Merge pull request #1856 from pxl-th/master
Fix type-stability for normalization layers
2 parents 8d3b8d3 + d151080 commit 5244ade


3 files changed: +87 -49 lines changed


src/layers/conv.jl

Lines changed: 1 addition & 1 deletion
@@ -664,7 +664,7 @@ end
 
 function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
   stride = expand(Val(N), stride)
-  pad = calc_padding(MaxPool ,pad, k, 1, stride)
+  pad = calc_padding(MaxPool, pad, k, 1, stride)
   return MaxPool(k, pad, stride)
 end

src/layers/normalise.jl

Lines changed: 57 additions & 44 deletions
Several hunks in this file and in the test file only strip trailing whitespace, so some removed and added lines look identical.

@@ -146,13 +146,13 @@ testmode!(m::AlphaDropout, mode=true) =
     LayerNorm(sz, λ=identity; affine=true, ϵ=1f-5)
 
 A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be
-used with recurrent hidden states.
-The argument `sz` should be an integer or a tuple of integers.
-In the forward pass, the layer normalises the mean and standard
+used with recurrent hidden states.
+The argument `sz` should be an integer or a tuple of integers.
+In the forward pass, the layer normalises the mean and standard
 deviation of the input, then applies the elementwise activation `λ`.
 The input is normalised along the first `length(sz)` dimensions
 for tuple `sz`, along the first dimension for integer `sz`.
-The input is expected to have first dimensions' size equal to `sz`.
+The input is expected to have first dimensions' size equal to `sz`.
 
 If `affine=true` also applies a learnable shift and rescaling
 as in the [`Diagonal`](@ref) layer.
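For orientation, here is a minimal usage sketch of the `LayerNorm` layer documented above; the shapes and variable names are illustrative and not part of the diff:

using Flux

ln = LayerNorm(5)           # normalise over the first dimension, which must have size 5
x  = randn(Float32, 5, 16)  # 5 features × 16 samples
y  = ln(x)                  # per-column mean ≈ 0 and unit variance, then the affine shift/rescale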
@@ -192,35 +192,48 @@ end
 # Compute the statistics on the slices specified by reduce_dims.
 # reduce_dims=[1,...,N-2,N] for BatchNorm
 # reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm
-function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape) where {T, N}
+function _norm_layer_forward(
+  l, x::AbstractArray{T, N}; reduce_dims, affine_shape,
+) where {T, N}
   if !_isactive(l) && l.track_stats # testmode with tracked stats
     stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
     μ = reshape(l.μ, stats_shape)
     σ² = reshape(l.σ², stats_shape)
-  else # trainmode or testmode without tracked stats
+  else # trainmode or testmode without tracked stats
     μ = mean(x; dims=reduce_dims)
     σ² = mean((x .- μ).^2; dims=reduce_dims)
     if l.track_stats
-      ## update moving mean/std
-      Zygote.ignore() do
-        mtm = l.momentum
-        m = prod(size(x, i) for i in reduce_dims) # needed for computing corrected var
-        μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
-        σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
-        l.μ = (1-mtm) .* l.μ .+ mtm .* μnew
-        l.σ² = (1-mtm) .* l.σ² .+ mtm .* (m / (m - one(eltype(l.σ²)))) .* σ²new
-      end
+      _track_stats!(l, x, μ, σ², reduce_dims) # update moving mean/std
     end
   end
-  if hasaffine(l)
-    γ = reshape(l.γ, affine_shape)
-    β = reshape(l.β, affine_shape)
-    return l.λ.(γ .* (x .- μ) ./ sqrt.(σ² .+ l.ϵ) .+ β)
-  else
-    return l.λ.((x .- μ) ./ sqrt.(σ² .+ l.ϵ))
-  end
+
+  o = _norm_layer_forward(x, μ, σ², l.ϵ)
+  hasaffine(l) || return l.λ.(o)
+
+  γ = reshape(l.γ, affine_shape)
+  β = reshape(l.β, affine_shape)
+  return l.λ.(γ .* o .+ β)
 end
 
+@inline _norm_layer_forward(x, μ, σ², ϵ) = (x .- μ) ./ sqrt.(σ² .+ ϵ)
+
+function _track_stats!(
+  bn, x::AbstractArray{T, N}, μ, σ², reduce_dims,
+) where {T, N}
+  V = eltype(bn.σ²)
+  mtm = bn.momentum
+  res_mtm = one(V) - mtm
+  m = prod(size(x, i) for i in reduce_dims)
+
+  μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
+  σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
+
+  bn.μ = res_mtm .* bn.μ .+ mtm .* μnew
+  bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
+  return nothing
+end
+Zygote.@nograd _track_stats!
+
 """
     BatchNorm(channels::Integer, λ=identity;
               initβ=zeros32, initγ=ones32,
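This hunk is the heart of the fix: the mutation of the running statistics moves into `_track_stats!`, which `Zygote.@nograd` keeps out of the differentiated code, while the plain arithmetic is isolated in a small `@inline` method. Below is a stripped-down sketch of the same pattern, using a hypothetical `Norm` struct rather than Flux's actual layer types:

using Statistics

mutable struct Norm{T}
  μ::Vector{T}
  σ²::Vector{T}
  momentum::T
end

# Pure arithmetic: trivially inferable and safe to differentiate through.
normalize(x, μ, σ², ϵ) = (x .- μ) ./ sqrt.(σ² .+ ϵ)

# Mutation of the running statistics, kept separate from the forward arithmetic;
# in the Flux code this is the part marked with Zygote.@nograd.
function track_stats!(n::Norm, μnew, σ²new)
  mtm = n.momentum
  n.μ  .= (1 - mtm) .* n.μ  .+ mtm .* vec(μnew)
  n.σ² .= (1 - mtm) .* n.σ² .+ mtm .* vec(σ²new)
  return nothing
end

function forward(n::Norm, x::AbstractMatrix; ϵ=1f-5)
  μ  = mean(x; dims=2)
  σ² = mean((x .- μ).^2; dims=2)
  track_stats!(n, μ, σ²)
  return normalize(x, μ, σ², ϵ)
end

Keeping the mutation out of the forward arithmetic lets the compiler infer a concrete return type for the layer call, which is exactly what the new `@inferred m(x)` checks in the test file verify.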
@@ -234,15 +247,15 @@ Given an array with `N` dimensions, call the `N-1`th the channel dimension. For
 a batch of feature vectors this is just the data dimension, for `WHCN` images
 it's the usual channel dimension.
 
-`BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
+`BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
 input slice and normalises the input accordingly.
 
-If `affine=true`, it also applies a shift and a rescale to the input
+If `affine=true`, it also applies a shift and a rescale to the input
 through learnable per-channel bias β and scale γ parameters.
 
-After normalisation, elementwise activation `λ` is applied.
+After normalisation, elementwise activation `λ` is applied.
 
-If `track_stats=true`, accumulates mean and var statistics in training phase
+If `track_stats=true`, accumulates mean and var statistics in training phase
 that will be used to renormalize the input in test phase.
 
 Use [`testmode!`](@ref) during inference.
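A short usage sketch for the docstring above; the array sizes are chosen only for illustration:

using Flux

bn = BatchNorm(3)                 # three channels; affine and track_stats default to true
x  = randn(Float32, 8, 8, 3, 4)   # WHCN layout: the channel dimension is N-1
Flux.trainmode!(bn)
y  = bn(x)                        # normalises with batch statistics and updates bn.μ, bn.σ²
Flux.testmode!(bn)                # subsequent calls use the tracked statistics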
@@ -272,7 +285,7 @@ mutable struct BatchNorm{F,V,N,W}
 end
 
 function BatchNorm(chs::Int, λ=identity;
-                   initβ=zeros32, initγ=ones32,
+                   initβ=zeros32, initγ=ones32,
                    affine=true, track_stats=true,
                    ϵ=1f-5, momentum=0.1f0)
 
@@ -282,8 +295,8 @@ function BatchNorm(chs::Int, λ=identity;
   σ² = track_stats ? ones32(chs) : nothing
 
   return BatchNorm(λ, β, γ,
-                   μ, σ², ϵ, momentum,
-                   affine, track_stats,
+                   μ, σ², ϵ, momentum,
+                   affine, track_stats,
                    nothing, chs)
 end

@@ -318,19 +331,19 @@ end
 [Instance Normalization](https://arxiv.org/abs/1607.08022) layer.
 `channels` should be the size of the channel dimension in your data (see below).
 
-Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
+Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
 For `WHCN` images it's the usual channel dimension.
 
-`InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1`
+`InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1`
 input slice and normalises the input accordingly.
 
-If `affine=true`, it also applies a shift and a rescale to the input
+If `affine=true`, it also applies a shift and a rescale to the input
 through learnable per-channel bias `β` and scale `γ` parameters.
 
-If `track_stats=true`, accumulates mean and var statistics in training phase
+If `track_stats=true`, accumulates mean and var statistics in training phase
 that will be used to renormalize the input in test phase.
 
-**Warning**: the defaults for `affine` and `track_stats` used to be `true`
+**Warning**: the defaults for `affine` and `track_stats` used to be `true`
 in previous Flux versions (< v0.12).
 """
 mutable struct InstanceNorm{F,V,N,W}
@@ -358,7 +371,7 @@ function InstanceNorm(chs::Int, λ=identity;
   σ² = track_stats ? ones32(chs) : nothing
 
   return InstanceNorm(λ, β, γ,
-                      μ, σ², ϵ, momentum,
+                      μ, σ², ϵ, momentum,
                       affine, track_stats,
                       nothing, chs)
 end
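An analogous usage sketch for `InstanceNorm` (illustrative shapes; note the defaults discussed in the docstring above):

using Flux

inorm = InstanceNorm(3)             # affine=false, track_stats=false by default since v0.12
x     = randn(Float32, 8, 8, 3, 4)  # needs more than two dimensions
y     = inorm(x)                    # statistics are computed per channel and per sample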
@@ -401,13 +414,13 @@ The number of channels must be an integer multiple of the number of groups.
 
 `channels` should be the size of the channel dimension in your data (see below).
 
-Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
+Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
 For `WHCN` images it's the usual channel dimension.
 
-If `affine=true`, it also applies a shift and a rescale to the input
+If `affine=true`, it also applies a shift and a rescale to the input
 through learnable per-channel bias `β` and scale `γ` parameters.
 
-If `track_stats=true`, accumulates mean and var statistics in training phase
+If `track_stats=true`, accumulates mean and var statistics in training phase
 that will be used to renormalize the input in test phase.
 """
 mutable struct GroupNorm{F,V,N,W}
429442
trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : ()
430443

431444
function GroupNorm(chs::Int, G::Int, λ=identity;
432-
initβ=zeros32, initγ=ones32,
445+
initβ=zeros32, initγ=ones32,
433446
affine=true, track_stats=false,
434447
ϵ=1f-5, momentum=0.1f0)
435448

@@ -440,11 +453,11 @@ function GroupNorm(chs::Int, G::Int, λ=identity;
   μ = track_stats ? zeros32(G) : nothing
   σ² = track_stats ? ones32(G) : nothing
 
-  return GroupNorm(G, λ,
+  return GroupNorm(G, λ,
                    β, γ,
-                   μ, σ²,
-                   ϵ, momentum,
-                   affine, track_stats,
+                   μ, σ²,
+                   ϵ, momentum,
+                   affine, track_stats,
                    nothing, chs)
 end
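And for `GroupNorm`, a sketch with an illustrative channel/group split:

using Flux

gn = GroupNorm(4, 2)                # 4 channels normalised in 2 groups of 2
x  = randn(Float32, 8, 8, 4, 4)
y  = gn(x)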

@@ -475,7 +488,7 @@ end
 """
     hasaffine(l)
 
-Return `true` if a normalisation layer has trainable shift and
+Return `true` if a normalisation layer has trainable shift and
 scale parameters, `false` otherwise.
 
 See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref).
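A quick illustration of `hasaffine` against the defaults documented above:

using Flux

Flux.hasaffine(BatchNorm(3))     # true: BatchNorm defaults to affine=true
Flux.hasaffine(InstanceNorm(3))  # false: InstanceNorm defaults to affine=false (since v0.12)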

test/layers/normalisation.jl

Lines changed: 29 additions & 4 deletions
@@ -149,39 +149,50 @@ end
   # 1.3
   # 1.3
   @test m.σ² ≈ .1 .* var(x, dims=2, corrected=false) .* (3 / 2) .+ .9 .* [1., 1.]
-
+
   x′ = m(x)
   @test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5)
+
+  @inferred m(x)
+end
+
+let m = BatchNorm(2; track_stats=false), x = [1.0 3.0 5.0; 2.0 4.0 6.0]
+  @inferred m(x)
 end
 
 # with activation function
 let m = BatchNorm(2, sigmoid), x = [1.0 3.0 5.0;
                                     2.0 4.0 6.0]
   y = m(x)
   @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
+  @inferred m(x)
 end
 
 let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
   y = reshape(permutedims(x, [2, 1, 3]), 2, :)
   y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
   @test m(x) == y
+  @inferred m(x)
 end
 
 let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
   y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
   y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
   @test m(x) == y
+  @inferred m(x)
 end
 
 let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
   y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
   y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
   @test m(x) == y
+  @inferred m(x)
 end
 
 let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1);
   m(x)
   @test (@allocated m(x)) < 100_000_000
+  @inferred m(x)
 end
 
 @test length(Flux.params(BatchNorm(10))) == 2
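The tests above lean on `Test.@inferred`, which throws an error if the return type Julia infers for the call does not match the type of the value actually returned, making it a direct check for the kind of type-instability this PR fixes. A self-contained sketch with hypothetical functions (not from the PR):

using Test

unstable(x) = x > 0 ? x : 0         # may return either Int or Float64
stable(x)   = x > 0 ? x : zero(x)   # always returns typeof(x)

@inferred stable(1.0)      # passes: the inferred type is concretely Float64
# @inferred unstable(1.0)  # would throw: inference gives Union{Float64, Int64}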
@@ -232,6 +243,8 @@ end
   @test length(m.μ) == 2
   @test length(m.σ²) == 2
   @test y ≈ (x .- reshape(m.μ, 1,2,1)) ./ sqrt.(reshape(m.σ², 1,2,1) .+ 1f-5) atol=1.0e-5
+
+  @inferred m(x)
 end
 
 # with activation function
@@ -242,35 +255,41 @@ end
   affine_shape[[1,3]] .= 1
 
   y = evalwgrad(m, x)
-  y = m(x) # inference time after a training step
+  y = m(x) # inference time after a training step
   μ = reshape(m.μ, affine_shape...)
   σ² = reshape(m.σ², affine_shape...)
   @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7
+
+  @inferred m(x)
 end
 
 # with activation function
 let m = InstanceNorm(2, sigmoid; affine=true, track_stats=false), sizes = (3, 2, 2),
     x = reshape(collect(1:prod(sizes)), sizes)
 
   @test Flux.hasaffine(m) == true
-  @test length(params(m)) == 2
+  @test length(params(m)) == 2
   x = Float64.(x)
   y = m(x)
   μ = mean(x, dims=1)
   σ² = var(x, dims=1, corrected=false)
   @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7
+
+  @inferred m(x)
 end
 
 let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
     x = reshape(collect(1:prod(sizes)), sizes)
   @test Flux.hasaffine(m) == false
   @test length(params(m)) == 0
-
+
   x = Float64.(x)
   y = m(x)
   μ = mean(x, dims=1)
   σ² = var(x, dims=1, corrected=false)
   @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7
+
+  @inferred m(x)
 end
 
@@ -279,6 +298,8 @@ end
   y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
   y = reshape(m(y), sizes...)
   @test m(x) == y
+
+  @inferred m(x)
 end
 
 # check that μ, σ², and the output are the correct size for higher rank tensors
@@ -288,6 +309,8 @@ end
   @test size(m.μ) == (sizes[end - 1], )
   @test size(m.σ²) == (sizes[end - 1], )
   @test size(y) == sizes
+
+  @inferred m(x)
 end
 
 # show that instance norm is equal to batch norm when channel and batch dims are squashed
@@ -299,6 +322,8 @@ end
 let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1);
   m(x)
   @test (@allocated m(x)) < 100_000_000
+
+  @inferred m(x)
 end
 
 @test length(Flux.params(InstanceNorm(10))) == 0
