
Commit 7e9a180

Merge #1397
1397: Rework normalization layers r=DhairyaLGandhi a=CarloLucibello

This goes in the direction of creating a powerful and consistent experience for the normalization layers.

- Add an activation function to LayerNorm, consistently with the other normalization layers.
- Generalize LayerNorm and Diagonal to handle multiple dimensions.
- Add the bool keyword `affine` to LayerNorm, GroupNorm, InstanceNorm, and BatchNorm to activate/deactivate a learnable shift and scaling.
- Add the keyword `track_stats` to InstanceNorm and BatchNorm to activate/deactivate tracking of the running mean and variance, used in the evaluation phase.
- Use PyTorch's defaults for `track_stats` and `affine`, since AFAIK they correspond to common usage and best practice:
  - For InstanceNorm, `affine=false` and `track_stats=false`. This breaks the current behaviour (both `true`).
  - For BatchNorm, `affine=true` and `track_stats=true`, as we are currently doing.
  - For GroupNorm, `affine=true` and `track_stats=false`.
- Unify the InstanceNorm, GroupNorm, and BatchNorm logic.
- Performance improvement for InstanceNorm, which previously used repetition instead of broadcasting.
- Fix InstanceNorm previously not working on gpu (fix #1195).
- Fix GroupNorm previously not working on gpu (fix #1247).
- Fix #1280, fix #1308, fix #1429, fix #1430, fix #802.

### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).

Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
2 parents ecde91b + ff18866 commit 7e9a180
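
To make the new keyword surface concrete, here is a small usage sketch based on the bullet points above. It is illustrative only: the `affine` and `track_stats` keywords and the defaults come from the PR description, while the remaining constructor arguments follow common Flux conventions and may differ slightly from the merged code.

```julia
using Flux

bn    = BatchNorm(64, relu)                                      # affine=true, track_stats=true by default
inorm = InstanceNorm(64, relu; affine=true, track_stats=true)    # both now default to false; opt back in explicitly
gn    = GroupNorm(64, 4, relu)                                   # affine=true, track_stats=false by default
ln    = LayerNorm(64, relu)                                      # LayerNorm now also takes an activation

x = randn(Float32, 28, 28, 64, 8)   # WHCN batch with 64 channels
y = bn(x)                           # use testmode!(bn) / trainmode!(bn) to control which statistics are used
```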

File tree: 14 files changed (+584 −370 lines)


Manifest.toml

Lines changed: 14 additions & 14 deletions
```diff
@@ -14,9 +14,9 @@ version = "0.3.3"
 
 [[Adapt]]
 deps = ["LinearAlgebra"]
-git-tree-sha1 = "87491f7d03ae1b423a353aff99cf61a45e3c993a"
+git-tree-sha1 = "ffcfa2d345aaee0ef3d8346a073d5dd03c983ebe"
 uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-version = "3.1.0"
+version = "3.2.0"
 
 [[Artifacts]]
 deps = ["Pkg"]
@@ -93,9 +93,9 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
 version = "0.3.4+0"
 
 [[DataAPI]]
-git-tree-sha1 = "ad84f52c0b8f05aa20839484dbaf01690b41ff84"
+git-tree-sha1 = "6d64b28d291cb94a0d84e6e41081fb081e7f717f"
 uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
-version = "1.4.0"
+version = "1.5.0"
 
 [[DataStructures]]
 deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
@@ -134,9 +134,9 @@ version = "0.1.3"
 
 [[FillArrays]]
 deps = ["LinearAlgebra", "Random", "SparseArrays"]
-git-tree-sha1 = "8bd8e47ff5d34b20f0aa9641988eb660590008bc"
+git-tree-sha1 = "50eabdace27aa27b143f65b65e762bb0112a7708"
 uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
-version = "0.11.0"
+version = "0.11.1"
 
 [[FixedPointNumbers]]
 deps = ["Statistics"]
@@ -146,9 +146,9 @@ version = "0.8.4"
 
 [[ForwardDiff]]
 deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"]
-git-tree-sha1 = "c26b56e9b9f0687f7ca887f6b6ded03d269e0e35"
+git-tree-sha1 = "d48a40c0f54f29a5c8748cfb3225719accc72b77"
 uuid = "f6369f11-7733-5829-9624-2563aa707210"
-version = "0.10.15"
+version = "0.10.16"
 
 [[Functors]]
 deps = ["MacroTools"]
@@ -227,9 +227,9 @@ version = "0.5.0"
 
 [[Missings]]
 deps = ["DataAPI"]
-git-tree-sha1 = "ed61674a0864832495ffe0a7e889c0da76b0f4c8"
+git-tree-sha1 = "f8c673ccc215eb50fcadb285f522420e29e69e1c"
 uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
-version = "0.4.4"
+version = "0.4.5"
 
 [[Mmap]]
 uuid = "a63ad114-7e13-5084-954f-fe012c677804"
@@ -252,9 +252,9 @@ uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
 version = "0.5.3+4"
 
 [[OrderedCollections]]
-git-tree-sha1 = "cf59cfed2e2c12e8a2ff0a4f1e9b2cd8650da6db"
+git-tree-sha1 = "d45739abcfc03b51f6a42712894a593f74c80a23"
 uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
-version = "1.3.2"
+version = "1.3.3"
 
 [[Pkg]]
 deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
@@ -375,9 +375,9 @@ version = "1.2.11+18"
 
 [[Zygote]]
 deps = ["AbstractFFTs", "ChainRules", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
-git-tree-sha1 = "746c9de7fb87a341c809437007cbd172c4d494b4"
+git-tree-sha1 = "52835a83f7c899cfcb95f796d584201812887ea8"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.2"
+version = "0.6.3"
 
 [[ZygoteRules]]
 deps = ["MacroTools"]
```

NEWS.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -11,7 +11,7 @@
 * Add [sparse initialization](https://github.com/FluxML/Flux.jl/pull/1454) as described in [Deep learning via Hessian-free optimization](https://dl.acm.org/doi/abs/10.5555/3104322.3104416).
 * Moved GPU CI to use buildkite instead of GitLab
 * New [`Parallel` layer](https://github.com/FluxML/Flux.jl/pull/1462) adds inception module-like building blocks.
-
+* Feature additions and bug fixes for BatchNorm, LayerNorm, InstanceNorm, and GroupNorm [normalization layers](https://github.com/FluxML/Flux.jl/pull/1397)
 
 ## v0.11.2
 
```

docs/src/models/basics.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -218,4 +218,4 @@ Flux.@functor Affine
 
 This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
 
-For some more helpful tricks, including parameter freezing, please checkout the [advanced usage guide](advanced.md).
+For some more helpful tricks, including parameter freezing, please checkout the [advanced usage guide](advanced.md).
```

src/cuda/cudnn.jl

Lines changed: 17 additions & 4 deletions
```diff
@@ -1,7 +1,20 @@
 import CUDA.CUDNN: batchnorm, ∇batchnorm
 
-(BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
-  BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = Flux.istraining()))
+function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},
+                              cache=nothing) where T<:Union{Float32, Float64}
+
+  @assert BN.affine "BatchNorm: only affine=true supported on gpu"
+  @assert BN.track_stats "BatchNorm: only track_stats=true supported on gpu"
+  @assert length(BN.β) == size(x, ndims(x)-1) "BatchNorm: input has wrong number of channels"
+  return BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum;
+                         cache=cache, alpha=1, beta=0, eps=BN.ϵ,
+                         training=Flux._isactive(BN)))
+end
 
-@adjoint batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
-  batchnorm(g, b, x, running_mean, running_var, momentum; kw...), Δ -> (∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing)
+@adjoint function batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
+  y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
+  function batchnorm_pullback(Δ)
+    ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing
+  end
+  y, batchnorm_pullback
+end
```
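
The rewritten `@adjoint` above follows Zygote's usual primal-plus-pullback shape, returning `nothing` for the arguments that are not differentiated (the running statistics and momentum). A minimal, self-contained sketch of the same idiom, using a hypothetical `scale_shift` function rather than anything from this diff:

```julia
using Zygote: @adjoint, gradient

# Hypothetical helper, not part of the PR: scalar scale-and-shift with a trailing
# argument (momentum) treated as non-differentiable, mirroring how the batchnorm
# pullback above appends `nothing` for running_mean, running_var, and momentum.
scale_shift(g, b, x, momentum) = g .* x .+ b

@adjoint function scale_shift(g, b, x, momentum)
  y = scale_shift(g, b, x, momentum)
  function scale_shift_pullback(Δ)
    (sum(Δ .* x), sum(Δ), g .* Δ, nothing)   # one gradient entry per positional argument
  end
  return y, scale_shift_pullback
end

# d/dg sum(g .* x .+ b) = sum(x) = 3.0, d/db = length(x) = 2.0
gradient((g, b) -> sum(scale_shift(g, b, [1.0, 2.0], 0.9)), 2.0, 1.0)  # (3.0, 2.0)
```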

src/deprecations.jl

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,7 +1,7 @@
 # v0.12 deprecations
 @deprecate Dropout(p, dims) Dropout(p; dims=dims)
-@deprecate InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
-@deprecate BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
+@deprecate InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, true, true, nothing)
+@deprecate BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, true, true, nothing)
 @deprecate GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum) GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, nothing)
 @deprecate outdims(f, inputsize) outputsize(f, inputsize)
 @deprecate Conv(; weight, bias, activation=identity, kws...) Conv(weight, bias, activation; kws...)
```

src/layers/basic.jl

Lines changed: 14 additions & 11 deletions
```diff
@@ -134,32 +134,35 @@ function Base.show(io::IO, l::Dense)
 end
 
 """
-    Diagonal(in::Integer)
+    Diagonal(α, β)
+    Diagonal(sz::Integer...; initα=ones, initβ=zeros)
 
-Create an element-wise linear transformation layer with learnable
-vectors `α` and `β`:
+Create an element-wise linear layer with learnable
+arrays `α` and `β` of size `sz`. The layer performs
 
     y = α .* x .+ β
 
-The input `x` must be a array where `size(x, 1) == in`.
+The input `x` must have size broadcast-compatible with `α` and `β`.
+The parameters will be created with the calls
+`α = initα(sz)` and `β = initβ(sz)`.
 """
 struct Diagonal{T}
   α::T
   β::T
 end
 
-Diagonal(in::Integer; initα = ones, initβ = zeros) =
-  Diagonal(initα(in), initβ(in))
+function Diagonal(sz::Integer...;
+                  initα = i -> ones(Float32, i),
+                  initβ = i -> zeros(Float32, i))
+  Diagonal(initα(sz), initβ(sz))
+end
 
 @functor Diagonal
 
-function (a::Diagonal)(x)
-  α, β = a.α, a.β
-  α.*x .+ β
-end
+(a::Diagonal)(x) = a.α .* x .+ a.β
 
 function Base.show(io::IO, l::Diagonal)
-  print(io, "Diagonal(", length(l.α), ")")
+  print(io, "Diagonal(", size(l.α), ")")
 end
 
 """
```
