
Commit 9e7cef2

Author: Anton Smirnov
Fix type-stability for normalization layers
2 parents: 8d3b8d3 + 4a3c80b
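For context, a minimal sketch (not part of the commit, assuming v0.12-era Flux defaults) of how the type stability this commit targets can be checked with `Test.@inferred`:

```julia
using Flux, Test

m = BatchNorm(3)                 # affine=true, track_stats=true by default
x = rand(Float32, 4, 4, 3, 2)    # a WHCN image batch

# The point of the fix: the forward pass should infer a concrete
# Array{Float32,4} result instead of falling back to Any.
@inferred m(x)
```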

File tree: 3 files changed (+125, -55 lines)

src/layers/conv.jl

Lines changed: 10 additions & 4 deletions
@@ -162,7 +162,9 @@ end

 function (c::Conv)(x::AbstractArray)
   σ, b = c.σ, reshape(c.bias, ntuple(_ -> 1, length(c.stride))..., :, 1)
-  cdims = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)
+  cdims = DenseConvDims(
+    x, c.weight; stride=c.stride, padding=c.pad,
+    dilation=c.dilation, groups=c.groups)
   σ.(conv(x, c.weight, cdims) .+ b)
 end
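As an aside on the unchanged bias handling above, a quick illustration (not part of the commit) of what the `reshape(c.bias, ...)` call produces for a 2-d convolution:

```julia
# length(c.stride) == 2 for a 2-d Conv, so a length-8 bias vector becomes a
# 1×1×8×1 array that broadcasts over width, height and the batch dimension.
bias = rand(Float32, 8)
size(reshape(bias, ntuple(_ -> 1, 2)..., :, 1))   # (1, 1, 8, 1)
```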

@@ -656,19 +658,23 @@ julia> lay(rand(Float32, 100, 7, 50)) |> size
 (34, 7, 50)
 ```
 """
-struct MaxPool{N,M}
+struct MaxPool{N, M}
   k::NTuple{N,Int}
   pad::NTuple{M,Int}
   stride::NTuple{N,Int}
 end

-function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
+function MaxPool(k::NTuple{N, Integer}; pad = 0, stride = k) where N
   stride = expand(Val(N), stride)
-  pad = calc_padding(MaxPool ,pad, k, 1, stride)
+  pad = calc_padding(MaxPool, pad, k, 1, stride)
   return MaxPool(k, pad, stride)
 end

 function (m::MaxPool)(x)
+  # size_x = size(x)
+  # kernel, stride, padding, dilation = NNlib.prepare_pooldims(
+  #   Val(N), size_x, m.k; padding=m.pad, stride=m.stride)
+  # pdims = PoolDims{kernel, stride, padding, dilation}(size_x)
   pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
   return maxpool(x, pdims)
 end
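A short usage sketch (not part of the commit) of the `MaxPool` layer touched above:

```julia
using Flux

mp = MaxPool((2, 2))              # stride defaults to the window size
x = rand(Float32, 28, 28, 3, 16)  # WHCN batch
size(mp(x))                       # (14, 14, 3, 16)
```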

src/layers/normalise.jl

Lines changed: 86 additions & 47 deletions
@@ -146,13 +146,13 @@ testmode!(m::AlphaDropout, mode=true) =
     LayerNorm(sz, λ=identity; affine=true, ϵ=1f-5)

 A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be
-used with recurrent hidden states.
-The argument `sz` should be an integer or a tuple of integers.
-In the forward pass, the layer normalises the mean and standard
+used with recurrent hidden states.
+The argument `sz` should be an integer or a tuple of integers.
+In the forward pass, the layer normalises the mean and standard
 deviation of the input, then applies the elementwise activation `λ`.
 The input is normalised along the first `length(sz)` dimensions
 for tuple `sz`, along the first dimension for integer `sz`.
-The input is expected to have first dimensions' size equal to `sz`.
+The input is expected to have first dimensions' size equal to `sz`.

 If `affine=true` also applies a learnable shift and rescaling
 as in the [`Diagonal`](@ref) layer.
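A brief usage sketch (not part of the commit) of `LayerNorm` as documented above:

```julia
using Flux

ln = LayerNorm(5)            # normalise over the first dimension (5 features)
x = rand(Float32, 5, 10)     # batch of 10 feature vectors
size(ln(x))                  # (5, 10)
```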
@@ -188,39 +188,78 @@ function Base.show(io::IO, l::LayerNorm)
   print(io, ")")
 end

+_maybe_promote_type(::Type{T1}, ::Type{T2}) where {T1, T2} = promote_type(T1, T2)
+_maybe_promote_type(::Type{Nothing}, ::Type{T2}) where T2 = T2
+_maybe_promote_type(::Type{T1}, ::Type{Nothing}) where T1 = T1
+
+_maybe_eltype(::Type{T}) where T <: AbstractArray = eltype(T)
+_maybe_eltype(::Type{Nothing}) = Nothing
+
+abstract type Normalization{F, V, N, W} end
+
+function _promote_to_output(
+  ::Normalization{F, V, N, W}, x::AbstractArray{T},
+) where {F, V, N, W, T}
+  Vel = _maybe_eltype(V)
+  Wel = _maybe_eltype(W)
+  _maybe_promote_type(_maybe_promote_type(
+    _maybe_promote_type(T, Vel), N), Wel)
+end
+
+function _basetype(::Type{T}) where T
+  if T <: Array
+    return Array
+  elseif T <: CuArray
+    return CuArray
+  end
+  throw("Unsupported type $T")
+end
+
 # For InstanceNorm, GroupNorm, and BatchNorm.
 # Compute the statistics on the slices specified by reduce_dims.
 # reduce_dims=[1,...,N-2,N] for BatchNorm
 # reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm
-function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape) where {T, N}
+function _norm_layer_forward(
+  l, x::AbstractArray{T, N}; reduce_dims, affine_shape,
+) where {T, N}
   if !_isactive(l) && l.track_stats # testmode with tracked stats
     stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
     μ = reshape(l.μ, stats_shape)
     σ² = reshape(l.σ², stats_shape)
-  else # trainmode or testmode without tracked stats
+  else # trainmode or testmode without tracked stats
     μ = mean(x; dims=reduce_dims)
    σ² = mean((x .- μ).^2; dims=reduce_dims)
    if l.track_stats
-      ## update moving mean/std
-      Zygote.ignore() do
-        mtm = l.momentum
-        m = prod(size(x, i) for i in reduce_dims) # needed for computing corrected var
-        μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
-        σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
-        l.μ = (1-mtm) .* l.μ .+ mtm .* μnew
-        l.σ² = (1-mtm) .* l.σ² .+ mtm .* (m / (m - one(eltype(l.σ²)))) .* σ²new
-      end
+      _track_stats!(l, x, μ, σ², reduce_dims) # update moving mean/std
    end
  end
-  if hasaffine(l)
-    γ = reshape(l.γ, affine_shape)
-    β = reshape(l.β, affine_shape)
-    return l.λ.(γ .* (x .- μ) ./ sqrt.(σ² .+ l.ϵ) .+ β)
-  else
-    return l.λ.((x .- μ) ./ sqrt.(σ² .+ l.ϵ))
-  end
+
+  O = _promote_to_output(l, x)
+  o::_basetype(typeof(x)){O, N} = ((x .- μ) ./ sqrt.(σ² .+ l.ϵ))
+  hasaffine(l) || return l.λ.(o)
+
+  γ = reshape(l.γ, affine_shape)
+  β = reshape(l.β, affine_shape)
+  return l.λ.(γ .* o .+ β)
 end

+function _track_stats!(
+  bn, x::AbstractArray{T, N}, μ, σ², reduce_dims,
+) where {T, N}
+  V = eltype(bn.σ²)
+  mtm = bn.momentum
+  res_mtm = one(V) - mtm
+  m = prod(size(x, i) for i in reduce_dims)
+
+  μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
+  σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
+
+  bn.μ = res_mtm .* bn.μ .+ mtm .* μnew
+  bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
+  nothing
+end
+Zygote.@nograd _track_stats!
+
 """
     BatchNorm(channels::Integer, λ=identity;
               initβ=zeros32, initγ=ones32,
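A rough illustration (not part of the commit) of what the new promotion helpers compute. `Nothing` stands in for an absent parameter, e.g. `β`/`γ` with `affine=false` or `μ`/`σ²` with `track_stats=false`, and is skipped during promotion:

```julia
# Hypothetical REPL session against the helpers added above.
_maybe_promote_type(Float32, Float64)   # Float64
_maybe_promote_type(Nothing, Float32)   # Float32
_maybe_promote_type(Float64, Nothing)   # Float64

# `_promote_to_output` folds the input eltype together with the eltypes of
# β/γ, ϵ and the tracked statistics, so a layer built with track_stats=false
# (μ and σ² are `nothing`) still yields a concrete output eltype.
bn = BatchNorm(4; track_stats=false)
bn.μ === nothing                        # true
```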
@@ -234,15 +273,15 @@ Given an array with `N` dimensions, call the `N-1`th the channel dimension. For
 a batch of feature vectors this is just the data dimension, for `WHCN` images
 it's the usual channel dimension.

-`BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
+`BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
 input slice and normalises the input accordingly.

-If `affine=true`, it also applies a shift and a rescale to the input
+If `affine=true`, it also applies a shift and a rescale to the input
 through to learnable per-channel bias β and scale γ parameters.

-After normalisation, elementwise activation `λ` is applied.
+After normalisation, elementwise activation `λ` is applied.

-If `track_stats=true`, accumulates mean and var statistics in training phase
+If `track_stats=true`, accumulates mean and var statistics in training phase
 that will be used to renormalize the input in test phase.

 Use [`testmode!`](@ref) during inference.
@@ -257,7 +296,7 @@ m = Chain(
   softmax)
 ```
 """
-mutable struct BatchNorm{F,V,N,W}
+mutable struct BatchNorm{F,V,N,W} <: Normalization{F, V, N, W}
   λ::F  # activation function
   β::V  # bias
   γ::V  # scale
@@ -272,7 +311,7 @@ mutable struct BatchNorm{F,V,N,W}
 end

 function BatchNorm(chs::Int, λ=identity;
-          initβ=zeros32, initγ=ones32,
+          initβ=zeros32, initγ=ones32,
           affine=true, track_stats=true,
           ϵ=1f-5, momentum=0.1f0)

@@ -282,8 +321,8 @@ function BatchNorm(chs::Int, λ=identity;
   σ² = track_stats ? ones32(chs) : nothing

   return BatchNorm(λ, β, γ,
-          μ, σ², ϵ, momentum,
-          affine, track_stats,
+          μ, σ², ϵ, momentum,
+          affine, track_stats,
           nothing, chs)
 end

@@ -318,22 +357,22 @@ end
 [Instance Normalization](https://arxiv.org/abs/1607.08022) layer.
 `channels` should be the size of the channel dimension in your data (see below).

-Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
+Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
 For `WHCN` images it's the usual channel dimension.

-`InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1`
+`InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1`
 input slice and normalises the input accordingly.

-If `affine=true`, it also applies a shift and a rescale to the input
+If `affine=true`, it also applies a shift and a rescale to the input
 through to learnable per-channel bias `β` and scale `γ` parameters.

-If `track_stats=true`, accumulates mean and var statistics in training phase
+If `track_stats=true`, accumulates mean and var statistics in training phase
 that will be used to renormalize the input in test phase.

-**Warning**: the defaults for `affine` and `track_stats` used to be `true`
+**Warning**: the defaults for `affine` and `track_stats` used to be `true`
 in previous Flux versions (< v0.12).
 """
-mutable struct InstanceNorm{F,V,N,W}
+mutable struct InstanceNorm{F,V,N,W} <: Normalization{F, V, N, W}
   λ::F  # activation function
   β::V  # bias
   γ::V  # scale
@@ -358,7 +397,7 @@ function InstanceNorm(chs::Int, λ=identity;
   σ² = track_stats ? ones32(chs) : nothing

   return InstanceNorm(λ, β, γ,
-          μ, σ², ϵ, momentum,
+          μ, σ², ϵ, momentum,
           affine, track_stats,
           nothing, chs)
 end
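A usage sketch (not part of the commit): `InstanceNorm` normalises each channel of each sample independently.

```julia
using Flux

m = InstanceNorm(3)            # affine=false, track_stats=false by default
x = rand(Float32, 8, 8, 3, 4)
size(m(x))                     # (8, 8, 3, 4)
```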
@@ -401,16 +440,16 @@ The number of channels must be an integer multiple of the number of groups.

 `channels` should be the size of the channel dimension in your data (see below).

-Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
+Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
 For `WHCN` images it's the usual channel dimension.

-If `affine=true`, it also applies a shift and a rescale to the input
+If `affine=true`, it also applies a shift and a rescale to the input
 through to learnable per-channel bias `β` and scale `γ` parameters.

-If `track_stats=true`, accumulates mean and var statistics in training phase
+If `track_stats=true`, accumulates mean and var statistics in training phase
 that will be used to renormalize the input in test phase.
 """
-mutable struct GroupNorm{F,V,N,W}
+mutable struct GroupNorm{F,V,N,W} <: Normalization{F, V, N, W}
   G::Int  # number of groups
   λ::F  # activation function
   β::V  # bias
@@ -429,7 +468,7 @@ end
 trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : ()

 function GroupNorm(chs::Int, G::Int, λ=identity;
-          initβ=zeros32, initγ=ones32,
+          initβ=zeros32, initγ=ones32,
           affine=true, track_stats=false,
           ϵ=1f-5, momentum=0.1f0)

@@ -440,11 +479,11 @@ function GroupNorm(chs::Int, G::Int, λ=identity;
   μ = track_stats ? zeros32(G) : nothing
   σ² = track_stats ? ones32(G) : nothing

-  return GroupNorm(G, λ,
+  return GroupNorm(G, λ,
           β, γ,
-          μ, σ²,
-          ϵ, momentum,
-          affine, track_stats,
+          μ, σ²,
+          ϵ, momentum,
+          affine, track_stats,
           nothing, chs)
 end
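A usage sketch (not part of the commit): six channels normalised in three groups of two.

```julia
using Flux

gn = GroupNorm(6, 3)           # 6 channels, 3 groups
x = rand(Float32, 8, 8, 6, 2)
size(gn(x))                    # (8, 8, 6, 2)
```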

@@ -475,7 +514,7 @@ end
 """
     hasaffine(l)

-Return `true` if a normalisation layer has trainable shift and
+Return `true` if a normalisation layer has trainable shift and
 scale parameters, `false` otherwise.

 See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref).
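A quick sketch (not part of the commit) of what `hasaffine` reports for the default constructors:

```julia
using Flux

Flux.hasaffine(BatchNorm(3))       # true  (affine=true by default)
Flux.hasaffine(InstanceNorm(3))    # false (affine=false by default)
Flux.hasaffine(LayerNorm(5))       # true
```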
