
Commit 72cda4a

Improve type stability of LayerNorm and Dropout
1 parent d66d2c4 commit 72cda4a

4 files changed: +88, -25 lines

src/Flux.jl

Lines changed: 3 additions & 1 deletion
@@ -9,7 +9,9 @@ using MacroTools: @forward
 using MLUtils
 import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions
 
-using Zygote, ChainRulesCore
+using ChainRulesCore
+
+using Zygote
 using Zygote: Params, @adjoint, gradient, pullback, @nograd
 export gradient

src/layers/normalise.jl

Lines changed: 41 additions & 17 deletions
@@ -2,7 +2,9 @@ istraining() = false
 
 ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)
 
-_isactive(m) = isnothing(m.active) ? istraining() : m.active
+_isactive(m) = isnothing(m.active) ? istraining() : Bool(m.active)
+
+ChainRulesCore.@non_differentiable _isactive(::Any)
 
 _dropout_shape(s, ::Colon) = size(s)
 _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) in enumerate(size(s)))...)
@@ -29,26 +31,51 @@ automatically managed using the [`Dropout`](@ref) layer instead of the
 
 The [`Dropout`](@ref) layer is what you should use in most scenarios.
 """
-function dropout(rng, x, p; dims=:, active::Bool=true)
-  active || return x
-  y = dropout_mask(rng, x, p, dims=dims)
-  return x .* y
-end
+dropout(rng, x, p; dims=:, active::Bool=true) = _dropout(rng, x, p, dims, active)
 dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
 
-dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
-dropout_mask(rng, x::CuArray, p; kwargs...) =
-  throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
-dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
-function _dropout_mask(rng, x, p; dims=:)
+# Internal function without kwargs to keep Zygote generated code type stable
+function _dropout(rng, x, p, dims, active)
+  mask = active ? dropout_mask(rng, x, p, dims) : nothing
+  return _apply_mask(x, mask)
+end
+
+function ChainRulesCore.rrule(::typeof(_dropout), rng, x, p, dims, active)
+  mask = active ? dropout_mask(rng, x, p, dims) : nothing
+  # Required because we don't always call dropout_mask
+  MT = Core.Compiler.return_type(dropout_mask, Tuple{typeof(rng),typeof(x),typeof(p),typeof(dims)})
+  project_x = ProjectTo(x)
+  return _apply_mask(x, mask), DropoutPullback{MT,typeof(project_x)}(mask, project_x)
+end
+
+# Also needed for type stability. Otherwise inference lifts the Union into a
+# Union{pullback{Nothing}, pullback{AbstractArray}}
+struct DropoutPullback{M<:AbstractArray,P<:ProjectTo{AbstractArray}}
+  mask::Union{Nothing,M}
+  project::P
+end
+
+function (pb::DropoutPullback)(dy)
+  dx = pb.project(_apply_mask(dy, pb.mask))
+  return (NoTangent(), NoTangent(), dx, NoTangent())
+end
+
+_apply_mask(x, ::Nothing) = x
+_apply_mask(x, mask) = x .* mask
+
+dropout_mask(rng::CUDA.RNG, x::CuArray, p, dims) = _dropout_mask(rng, x, p, dims)
+dropout_mask(rng, x::CuArray, p, dims) =
+  throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only supports CUDA.RNG for CuArrays."))
+dropout_mask(rng, x, p, dims) = _dropout_mask(rng, x, p, dims)
+function _dropout_mask(rng, x, p, dims)
   realfptype = float(real(eltype(x)))
   y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims)))
   y .= _dropout_kernel.(y, p, 1 - p)
   return y
 end
 
 # TODO move this to NNlib
-ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
+ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any, ::Any)
 
 """
     Dropout(p; dims=:, rng = rng_from_array())
@@ -106,10 +133,7 @@ end
 @functor Dropout
 trainable(a::Dropout) = (;)
 
-function (a::Dropout)(x)
-  _isactive(a) || return x
-  return dropout(a.rng, x, a.p; dims=a.dims, active=true)
-end
+(a::Dropout)(x) = _dropout(a.rng, x, a.p, a.dims, _isactive(a))
 
 testmode!(m::Dropout, mode=true) =
   (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
@@ -226,7 +250,7 @@ LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]
 
 @functor LayerNorm
 
-(a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
+(a::LayerNorm)(x) = a.diag(_normalize(x, 1:length(a.size), a.ϵ))
 
 function Base.show(io::IO, l::LayerNorm)
   print(io, "LayerNorm(", join(l.size, ", "))
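The kwarg-free _dropout path and the DropoutPullback struct above exist so that the pullback Zygote generates has a single concrete type whether or not the layer is active. A minimal sketch of how that can be checked, assuming this commit's Flux and Zygote and the three-argument Dropout(p, dims, active) constructor exercised in the tests further down:

using Flux, Zygote, Test

x = rand(10)

for active in (true, false)
  m = Dropout(0.5, :, active)      # explicit active flag; no kwargs on the differentiated path
  y, back = Zygote.pullback(m, x)  # ChainRules-level pullback is a DropoutPullback in both branches
  @inferred back(ones(10))         # gradient type no longer depends on `active`
end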

src/layers/stateless.jl

Lines changed: 32 additions & 6 deletions
@@ -26,15 +26,41 @@ function flatten(x::AbstractArray)
   return reshape(x, :, size(x)[end])
 end
 
+# Utils for LayerNorm internals.
+# Most of these are required for better performance and type stability under AD.
+# In an ideal world, we'd just have normalise.
+
+function _mean_std(x::AbstractArray, dims)
+  μ = mean(x, dims=dims)
+  σ = std(x, dims=dims, mean=μ, corrected=false)
+  return μ, σ
+end
+
+function ChainRulesCore.rrule(::typeof(_mean_std), x::AbstractArray, dims)
+  μ, mean_pullback = ChainRulesCore.rrule(mean, x, dims=dims)
+  σ, std_pullback = ChainRulesCore.rrule(std, x, dims=dims, mean=μ, corrected=false)
+  function _mean_std_pullback((dμ, dσ))
+    dx = ChainRulesCore.add!!(std_pullback(dσ)[2], mean_pullback(dμ)[2])
+    return (NoTangent(), dx, NoTangent())
+  end
+
+  return (μ, σ), _mean_std_pullback
+end
+
+_zscore(x, μ, σ, ϵ) = (x - μ) / (σ + ϵ)
+
+# We don't define a rrule for the whole function because we want
+# AD to figure out the _zscore broadcast for us.
+function _normalize(x::AbstractArray, dims, ϵ)
+  μ, σ = _mean_std(x, dims)
+  return _zscore.(x, μ, σ, ϵ)
+end
+
 """
     normalise(x; dims=ndims(x), ϵ=1e-5)
 
 Normalise `x` to mean 0 and standard deviation 1 across the dimension(s) given by `dims`.
-Per default, `dims` is the last dimension. 
+Per default, `dims` is the last dimension.
 `ϵ` is a small additive factor added to the denominator for numerical stability.
 """
-@inline function normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5))
-  μ = mean(x, dims=dims)
-  σ = std(x, dims=dims, mean=μ, corrected=false)
-  return @. (x - μ) / (σ + ϵ)
-end
+@inline normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5)) = _normalize(x, dims, ϵ)
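For reference, a small sketch (assuming Statistics and this commit's Flux) of what the new _normalize/_mean_std/_zscore decomposition computes: the same z-score with an additive ϵ in the denominator as the old fused broadcast, just split apart so the hand-written _mean_std rrule can combine the mean and std pullbacks with ChainRulesCore.add!! while AD differentiates the _zscore broadcast on its own.

using Flux, Statistics

x = rand(Float32, 5, 3)
ϵ = 1f-5

μ = mean(x, dims=1)
σ = std(x, dims=1, mean=μ, corrected=false)  # uncorrected std, matching _mean_std
manual = (x .- μ) ./ (σ .+ ϵ)                # elementwise form of _zscore

Flux.normalise(x, dims=1, ϵ=ϵ) ≈ manual      # should hold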

test/layers/normalisation.jl

Lines changed: 12 additions & 1 deletion
@@ -73,6 +73,12 @@ evalwgrad(f, x...) = pullback(f, x...)[1]
       @test cpu(m).rng === only(values(rng_kwargs))
     end
   end
+
+  for active in (true, false)
+    m = Dropout(0.5, :, active)
+    @inferred _, back = pullback(m, rand(10)) # _, DropoutPullback{Array{Float64}}
+    @inferred back(ones(10)) # Array{Float64}
+  end
 end
 
 @testset "AlphaDropout" begin
@@ -343,8 +349,13 @@ end
   @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1)
   x = rand(2,3,4,5)
   @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1)
+
   x = rand(2)
-  @test LayerNorm(2, tanh)(x) ≈ tanh.(Flux.normalise(x, dims=1))
+  m = LayerNorm(2, tanh)
+  @test m(x) ≈ tanh.(Flux.normalise(x, dims=1))
+  @inferred _, back = pullback(sum∘m, x)
+  @inferred back(1.0)
+
 
   x = rand(2,3,4,5)
   @test LayerNorm((2,3))(x) ≈ Flux.normalise(x, dims=(1,2))
