
Commit 128226d

Author: Anton Smirnov (committed)
Fix type-stability for normalization layers
1 parent 7f375aa commit 128226d

File tree: 5 files changed (+166 / -103 lines)


Project.toml

Lines changed: 2 additions & 2 deletions

@@ -38,8 +38,8 @@ Colors = "0.12"
 Functors = "0.2.1"
 Juno = "0.8"
 MacroTools = "0.5"
-NNlib = "0.7.24"
-NNlibCUDA = "0.1.7"
+NNlib = "0.8.0"
+NNlibCUDA = "0.2.0"
 Reexport = "0.2, 1.0"
 StatsBase = "0.33"
 ZipFile = "0.9"
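The bump above pins this branch of Flux to NNlib 0.8.x and NNlibCUDA 0.2.x under Julia's semver-style compat rules. A quick way to confirm which versions an environment actually resolved after pulling in this change (not part of the commit, just a sketch):

```julia
using Pkg

# In an environment that includes this branch of Flux, these should
# report an NNlib 0.8.x and an NNlibCUDA 0.2.x entry.
Pkg.status("NNlib")
Pkg.status("NNlibCUDA")
```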

src/layers/conv.jl

Lines changed: 10 additions & 4 deletions

@@ -162,7 +162,9 @@ end
 
 function (c::Conv)(x::AbstractArray)
   σ, b = c.σ, reshape(c.bias, ntuple(_ -> 1, length(c.stride))..., :, 1)
-  cdims = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)
+  cdims = DenseConvDims(
+    x, c.weight; stride=c.stride, padding=c.pad,
+    dilation=c.dilation, groups=c.groups)
   σ.(conv(x, c.weight, cdims) .+ b)
 end
 
@@ -656,19 +658,23 @@ julia> lay(rand(Float32, 100, 7, 50)) |> size
 (34, 7, 50)
 ```
 """
-struct MaxPool{N,M}
+struct MaxPool{N, M}
   k::NTuple{N,Int}
   pad::NTuple{M,Int}
   stride::NTuple{N,Int}
 end
 
-function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
+function MaxPool(k::NTuple{N, Integer}; pad = 0, stride = k) where N
   stride = expand(Val(N), stride)
-  pad = calc_padding(MaxPool ,pad, k, 1, stride)
+  pad = calc_padding(MaxPool, pad, k, 1, stride)
   return MaxPool(k, pad, stride)
 end
 
 function (m::MaxPool)(x)
+  # size_x = size(x)
+  # kernel, stride, padding, dilation = NNlib.prepare_pooldims(
+  #   Val(N), size_x, m.k; padding=m.pad, stride=m.stride)
+  # pdims = PoolDims{kernel, stride, padding, dilation}(size_x)
   pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
   return maxpool(x, pdims)
 end
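For context on the calls being reformatted here: a `Conv` forward pass builds a `DenseConvDims` descriptor from the layer's fields and hands it to `NNlib.conv`, and `MaxPool` does the same with `PoolDims`/`NNlib.maxpool`. A standalone sketch of those calls, with a made-up layer and input size, and bias/activation omitted:

```julia
using Flux, NNlib

x = rand(Float32, 28, 28, 3, 1)   # WHCN input
c = Conv((3, 3), 3 => 8, relu)

# Same descriptor the layer builds internally, then the raw NNlib call.
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad,
                      dilation=c.dilation, groups=c.groups)
y = NNlib.conv(x, c.weight, cdims)
size(y)                           # (26, 26, 8, 1)

# MaxPool goes through PoolDims / maxpool in the same way.
pdims = PoolDims(y, (2, 2))
size(NNlib.maxpool(y, pdims))     # (13, 13, 8, 1)
```

The commented-out `prepare_pooldims` lines in the diff appear to be left-over exploration; being comments, they are not exercised by this path.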

src/layers/normalise.jl

Lines changed: 79 additions & 47 deletions
@@ -122,13 +122,13 @@ testmode!(m::AlphaDropout, mode=true) =
     LayerNorm(sz, λ=identity; affine=true, ϵ=1fe-5)
 
 A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be
-used with recurrent hidden states.
-The argument `sz` should be an integer or a tuple of integers.
-In the forward pass, the layer normalises the mean and standard
+used with recurrent hidden states.
+The argument `sz` should be an integer or a tuple of integers.
+In the forward pass, the layer normalises the mean and standard
 deviation of the input, the applied the elementwise activation `λ`.
 The input is normalised along the first `length(sz)` dimensions
 for tuple `sz`, along the first dimension for integer `sz`.
-The input is expected to have first dimensions' size equal to `sz`.
+The input is expected to have first dimensions' size equal to `sz`.
 
 If `affine=true` also applies a learnable shift and rescaling
 as in the [`Diagonal`](@ref) layer.
@@ -164,38 +164,70 @@ function Base.show(io::IO, l::LayerNorm)
   print(io, ")")
 end
 
+_maybe_promote_type(::Type{T1}, ::Type{T2}) where {T1, T2} = promote_type(T1, T2)
+_maybe_promote_type(::Type{Nothing}, ::Type{T2}) where T2 = T2
+_maybe_promote_type(::Type{T1}, ::Type{Nothing}) where T1 = T1
+
+_maybe_eltype(::Type{T}) where T <: AbstractArray = eltype(T)
+_maybe_eltype(::Type{Nothing}) = Nothing
+
+abstract type Normalization{F, V, N, W} end
+
+function _promote_to_output(
+  ::Normalization{F, V, N, W}, x::AbstractArray{T},
+) where {F, V, N, W, T}
+  Vel = _maybe_eltype(V)
+  Wel = _maybe_eltype(W)
+  _maybe_promote_type(_maybe_promote_type(
+    _maybe_promote_type(T, Vel), N), Wel)
+end
+
 # For InstanceNorm, GroupNorm, and BatchNorm.
 # Compute the statistics on the slices specified by reduce_dims.
 # reduce_dims=[1,...,N-2,N] for BatchNorm
 # reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm
-function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape) where {T, N}
+_norm_layer_forward(l, x; reduce_dims, affine_shape) =
+  _norm_layer_forward(l, x, _promote_to_output(l, x); reduce_dims, affine_shape)
+
+function _norm_layer_forward(
+  l, x::Array{T, N}, ::Type{O}; reduce_dims, affine_shape,
+) where {T, N, O}
   if !_isactive(l) && l.track_stats # testmode with tracked stats
     stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
     μ = reshape(l.μ, stats_shape)
     σ² = reshape(l.σ², stats_shape)
-  else # trainmode or testmode without tracked stats
+  else # trainmode or testmode without tracked stats
     μ = mean(x; dims=reduce_dims)
     σ² = mean((x .- μ).^2; dims=reduce_dims)
     if l.track_stats
-      ## update moving mean/std
-      Zygote.ignore() do
-        mtm = l.momentum
-        m = prod(size(x, i) for i in reduce_dims) # needed for computing corrected var
-        μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
-        σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
-        l.μ = (1-mtm) .* l.μ .+ mtm .* μnew
-        l.σ² = (1-mtm) .* l.σ² .+ mtm .* (m / (m - one(eltype(l.σ²)))) .* σ²new
-      end
+      _track_stats!(l, x, μ, σ², reduce_dims) # update moving mean/std
     end
   end
-  if hasaffine(l)
-    γ = reshape(l.γ, affine_shape)
-    β = reshape(l.β, affine_shape)
-    return l.λ.(γ .* (x .- μ) ./ sqrt.(σ² .+ l.ϵ) .+ β)
-  else
-    return l.λ.((x .- μ) ./ sqrt.(σ² .+ l.ϵ))
-  end
+
+  o::Array{O, N} = ((x .- μ) ./ sqrt.(σ² .+ l.ϵ))
+  hasaffine(l) || return l.λ.(o)
+
+  γ = reshape(l.γ, affine_shape)
+  β = reshape(l.β, affine_shape)
+  return l.λ.(γ .* o .+ β)
+end
+
+function _track_stats!(
+  bn, x::AbstractArray{T, N}, μ, σ², reduce_dims,
+) where {T, N}
+  V = eltype(bn.σ²)
+  mtm = bn.momentum
+  res_mtm = one(V) - mtm
+  m = prod(size(x, i) for i in reduce_dims)
+
+  μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
+  σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
+
+  bn.μ = res_mtm .* bn.μ .+ mtm .* μnew
+  bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
+  nothing
 end
+Zygote.@nograd _track_stats!
 
 """
     BatchNorm(channels::Integer, λ=identity;
@@ -210,15 +242,15 @@ Given an array with `N` dimensions, call the `N-1`th the channel dimension. For
 a batch of feature vectors this is just the data dimension, for `WHCN` images
 it's the usual channel dimension.
 
-`BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
+`BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
 input slice and normalises the input accordingly.
 
-If `affine=true`, it also applies a shift and a rescale to the input
+If `affine=true`, it also applies a shift and a rescale to the input
 through to learnable per-channel bias β and scale γ parameters.
 
-After normalisation, elementwise activation `λ` is applied.
+After normalisation, elementwise activation `λ` is applied.
 
-If `track_stats=true`, accumulates mean and var statistics in training phase
+If `track_stats=true`, accumulates mean and var statistics in training phase
 that will be used to renormalize the input in test phase.
 
 Use [`testmode!`](@ref) during inference.
@@ -233,7 +265,7 @@ m = Chain(
   softmax)
 ```
 """
-mutable struct BatchNorm{F,V,N,W}
+mutable struct BatchNorm{F,V,N,W} <: Normalization{F, V, N, W}
   λ::F  # activation function
   β::V  # bias
   γ::V  # scale
@@ -248,7 +280,7 @@ mutable struct BatchNorm{F,V,N,W}
 end
 
 function BatchNorm(chs::Int, λ=identity;
-          initβ=zeros32, initγ=ones32,
+          initβ=zeros32, initγ=ones32,
           affine=true, track_stats=true,
           ϵ=1f-5, momentum=0.1f0)
 
@@ -258,8 +290,8 @@ function BatchNorm(chs::Int, λ=identity;
   σ² = track_stats ? ones32(chs) : nothing
 
   return BatchNorm(λ, β, γ,
-          μ, σ², ϵ, momentum,
-          affine, track_stats,
+          μ, σ², ϵ, momentum,
+          affine, track_stats,
           nothing, chs)
 end
 
@@ -294,22 +326,22 @@ end
 [Instance Normalization](https://arxiv.org/abs/1607.08022) layer.
 `channels` should be the size of the channel dimension in your data (see below).
 
-Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
+Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
 For `WHCN` images it's the usual channel dimension.
 
-`InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1`
+`InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1`
 input slice and normalises the input accordingly.
 
-If `affine=true`, it also applies a shift and a rescale to the input
+If `affine=true`, it also applies a shift and a rescale to the input
 through to learnable per-channel bias `β` and scale `γ` parameters.
 
-If `track_stats=true`, accumulates mean and var statistics in training phase
+If `track_stats=true`, accumulates mean and var statistics in training phase
 that will be used to renormalize the input in test phase.
 
-**Warning**: the defaults for `affine` and `track_stats` used to be `true`
+**Warning**: the defaults for `affine` and `track_stats` used to be `true`
 in previous Flux versions (< v0.12).
 """
-mutable struct InstanceNorm{F,V,N,W}
+mutable struct InstanceNorm{F,V,N,W} <: Normalization{F, V, N, W}
   λ::F  # activation function
   β::V  # bias
   γ::V  # scale
@@ -334,7 +366,7 @@ function InstanceNorm(chs::Int, λ=identity;
   σ² = track_stats ? ones32(chs) : nothing
 
   return InstanceNorm(λ, β, γ,
-          μ, σ², ϵ, momentum,
+          μ, σ², ϵ, momentum,
           affine, track_stats,
           nothing, chs)
 end
@@ -377,16 +409,16 @@ The number of channels must be an integer multiple of the number of groups.
 
 `channels` should be the size of the channel dimension in your data (see below).
 
-Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
+Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
 For `WHCN` images it's the usual channel dimension.
 
-If `affine=true`, it also applies a shift and a rescale to the input
+If `affine=true`, it also applies a shift and a rescale to the input
 through to learnable per-channel bias `β` and scale `γ` parameters.
 
-If `track_stats=true`, accumulates mean and var statistics in training phase
+If `track_stats=true`, accumulates mean and var statistics in training phase
 that will be used to renormalize the input in test phase.
 """
-mutable struct GroupNorm{F,V,N,W}
+mutable struct GroupNorm{F,V,N,W} <: Normalization{F, V, N, W}
   G::Int  # number of groups
   λ::F  # activation function
   β::V  # bias
@@ -405,7 +437,7 @@ end
 trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : ()
 
 function GroupNorm(chs::Int, G::Int, λ=identity;
-          initβ=zeros32, initγ=ones32,
+          initβ=zeros32, initγ=ones32,
           affine=true, track_stats=false,
           ϵ=1f-5, momentum=0.1f0)
 
@@ -416,11 +448,11 @@ function GroupNorm(chs::Int, G::Int, λ=identity;
   μ = track_stats ? zeros32(G) : nothing
   σ² = track_stats ? ones32(G) : nothing
 
-  return GroupNorm(G, λ,
+  return GroupNorm(G, λ,
           β, γ,
-          μ, σ²,
-          ϵ, momentum,
-          affine, track_stats,
+          μ, σ²,
+          ϵ, momentum,
+          affine, track_stats,
           nothing, chs)
 end
 
@@ -451,7 +483,7 @@ end
 """
     hasaffine(l)
 
-Return `true` if a normalisation layer has trainable shift and
+Return `true` if a normalisation layer has trainable shift and
 scale parameters, `false` otherwise.
 
 See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref).
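Stepping back from the individual hunks: the forward pass is now split into a thin wrapper that computes the output element type from the input and the layer's parameter types via `_promote_to_output` (with `nothing` fields, i.e. `affine=false` or `track_stats=false`, treated as imposing no constraint), and a kernel behind that function barrier whose result is annotated as `Array{O, N}`; the mutating running-statistics update lives in `_track_stats!`, kept out of the gradient with `Zygote.@nograd`. A minimal check of the intended effect on a build containing this commit might look like the following; the layer, sizes, and use of `@inferred` from the Test stdlib are illustrative, not taken from the commit's tests:

```julia
using Flux, Test

bn = BatchNorm(3)
x = rand(Float32, 5, 5, 3, 2)

# In test mode the tracked statistics are used; with a type-stable forward
# pass `@inferred` should succeed and the Float32 input stays Float32.
Flux.testmode!(bn)
y = @inferred bn(x)
@test eltype(y) == Float32
```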
