Skip to content

Commit ea50787

Browse files
Author: Anton Smirnov (committed)
Commit message: Use inference barrier
1 parent c85ec1e commit ea50787

File tree

1 file changed

+7
-26
lines changed

1 file changed

+7
-26
lines changed

src/layers/normalise.jl

Lines changed: 7 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -195,26 +195,6 @@ _maybe_promote_type(::Type{T1}, ::Type{Nothing}) where T1 = T1
195195
_maybe_eltype(::Type{T}) where T <: AbstractArray = eltype(T)
196196
_maybe_eltype(::Type{Nothing}) = Nothing
197197

198-
abstract type Normalization{F, V, N, W} end
199-
200-
function _promote_to_output(
201-
::Normalization{F, V, N, W}, x::AbstractArray{T},
202-
) where {F, V, N, W, T}
203-
Vel = _maybe_eltype(V)
204-
Wel = _maybe_eltype(W)
205-
_maybe_promote_type(_maybe_promote_type(
206-
_maybe_promote_type(T, Vel), N), Wel)
207-
end
208-
209-
function _basetype(::Type{T}) where T
210-
if T <: Array
211-
return Array
212-
elseif T <: CuArray
213-
return CuArray
214-
end
215-
throw("Unsupported type $T")
216-
end
217-
218198
# For InstanceNorm, GroupNorm, and BatchNorm.
219199
# Compute the statistics on the slices specified by reduce_dims.
220200
# reduce_dims=[1,...,N-2,N] for BatchNorm
@@ -234,15 +214,16 @@ function _norm_layer_forward(
234214
end
235215
end
236216

237-
O = _promote_to_output(l, x)
238-
o::_basetype(typeof(x)){O, N} = ((x .- μ) ./ sqrt.(σ² .+ l.ϵ))
217+
o = _norm_layer_forward(x, μ, σ², l.ϵ)
239218
hasaffine(l) || return l.λ.(o)
240219

241220
γ = reshape(l.γ, affine_shape)
242221
β = reshape(l.β, affine_shape)
243222
return l.λ.(γ .* o .+ β)
244223
end
245224

225+
@inline _norm_layer_forward(x, μ, σ², ϵ) = (x .- μ) ./ sqrt.(σ² .+ ϵ)
226+
246227
function _track_stats!(
247228
bn, x::AbstractArray{T, N}, μ, σ², reduce_dims,
248229
) where {T, N}
@@ -256,7 +237,7 @@ function _track_stats!(
256237

257238
bn.μ = res_mtm .* bn.μ .+ mtm .* μnew
258239
bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
259-
nothing
240+
return nothing
260241
end
261242
Zygote.@nograd _track_stats!
262243

@@ -296,7 +277,7 @@ m = Chain(
296277
softmax)
297278
```
298279
"""
299-
mutable struct BatchNorm{F,V,N,W} <: Normalization{F, V, N, W}
280+
mutable struct BatchNorm{F,V,N,W}
300281
λ::F # activation function
301282
β::V # bias
302283
γ::V # scale
@@ -372,7 +353,7 @@ that will be used to renormalize the input in test phase.
372353
**Warning**: the defaults for `affine` and `track_stats` used to be `true`
373354
in previous Flux versions (< v0.12).
374355
"""
375-
mutable struct InstanceNorm{F,V,N,W} <: Normalization{F, V, N, W}
356+
mutable struct InstanceNorm{F,V,N,W}
376357
λ::F # activation function
377358
β::V # bias
378359
γ::V # scale
@@ -449,7 +430,7 @@ through to learnable per-channel bias `β` and scale `γ` parameters.
449430
If `track_stats=true`, accumulates mean and var statistics in training phase
450431
that will be used to renormalize the input in test phase.
451432
"""
452-
mutable struct GroupNorm{F,V,N,W} <: Normalization{F, V, N, W}
433+
mutable struct GroupNorm{F,V,N,W}
453434
G::Int # number of groups
454435
λ::F # activation function
455436
β::V # bias

0 commit comments

Comments (0)