@@ -195,26 +195,6 @@ _maybe_promote_type(::Type{T1}, ::Type{Nothing}) where T1 = T1
195
195
_maybe_eltype (:: Type{T} ) where T <: AbstractArray = eltype (T)
196
196
_maybe_eltype (:: Type{Nothing} ) = Nothing
197
197
198
- abstract type Normalization{F, V, N, W} end
199
-
200
- function _promote_to_output (
201
- :: Normalization{F, V, N, W} , x:: AbstractArray{T} ,
202
- ) where {F, V, N, W, T}
203
- Vel = _maybe_eltype (V)
204
- Wel = _maybe_eltype (W)
205
- _maybe_promote_type (_maybe_promote_type (
206
- _maybe_promote_type (T, Vel), N), Wel)
207
- end
208
-
209
- function _basetype (:: Type{T} ) where T
210
- if T <: Array
211
- return Array
212
- elseif T <: CuArray
213
- return CuArray
214
- end
215
- throw (" Unsupported type $T " )
216
- end
217
-
218
198
# For InstanceNorm, GroupNorm, and BatchNorm.
219
199
# Compute the statistics on the slices specified by reduce_dims.
220
200
# reduce_dims=[1,...,N-2,N] for BatchNorm
@@ -234,15 +214,16 @@ function _norm_layer_forward(
234
214
end
235
215
end
236
216
237
- O = _promote_to_output (l, x)
238
- o:: _basetype (typeof (x)){O, N} = ((x .- μ) ./ sqrt .(σ² .+ l. ϵ))
217
+ o = _norm_layer_forward (x, μ, σ², l. ϵ)
239
218
hasaffine (l) || return l. λ .(o)
240
219
241
220
γ = reshape (l. γ, affine_shape)
242
221
β = reshape (l. β, affine_shape)
243
222
return l. λ .(γ .* o .+ β)
244
223
end
245
224
225
+ @inline _norm_layer_forward (x, μ, σ², ϵ) = (x .- μ) ./ sqrt .(σ² .+ ϵ)
226
+
246
227
function _track_stats! (
247
228
bn, x:: AbstractArray{T, N} , μ, σ², reduce_dims,
248
229
) where {T, N}
@@ -256,7 +237,7 @@ function _track_stats!(
256
237
257
238
bn. μ = res_mtm .* bn. μ .+ mtm .* μnew
258
239
bn. σ² = res_mtm .* bn. σ² .+ mtm .* (m / (m - one (V))) .* σ²new
259
- nothing
240
+ return nothing
260
241
end
261
242
Zygote. @nograd _track_stats!
262
243
@@ -296,7 +277,7 @@ m = Chain(
296
277
softmax)
297
278
```
298
279
"""
299
- mutable struct BatchNorm{F,V,N,W} <: Normalization{F, V, N, W}
280
+ mutable struct BatchNorm{F,V,N,W}
300
281
λ:: F # activation function
301
282
β:: V # bias
302
283
γ:: V # scale
@@ -372,7 +353,7 @@ that will be used to renormalize the input in test phase.
372
353
**Warning**: the defaults for `affine` and `track_stats` used to be `true`
373
354
in previous Flux versions (< v0.12).
374
355
"""
375
- mutable struct InstanceNorm{F,V,N,W} <: Normalization{F, V, N, W}
356
+ mutable struct InstanceNorm{F,V,N,W}
376
357
λ:: F # activation function
377
358
β:: V # bias
378
359
γ:: V # scale
@@ -449,7 +430,7 @@ through to learnable per-channel bias `β` and scale `γ` parameters.
449
430
If `track_stats=true`, accumulates mean and var statistics in training phase
450
431
that will be used to renormalize the input in test phase.
451
432
"""
452
- mutable struct GroupNorm{F,V,N,W} <: Normalization{F, V, N, W}
433
+ mutable struct GroupNorm{F,V,N,W}
453
434
G:: Int # number of groups
454
435
λ:: F # activation function
455
436
β:: V # bias
0 commit comments