
Commit 25f0a1b

Improve type stability of LayerNorm and Dropout

1 parent 952c4a5 commit 25f0a1b

3 files changed: 74 additions, 23 deletions

src/Flux.jl

Lines changed: 3 additions & 1 deletion

@@ -9,7 +9,9 @@ using MacroTools: @forward
 using MLUtils
 import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions

-using Zygote, ChainRulesCore
+using ChainRulesCore
+
+using Zygote
 using Zygote: Params, @adjoint, gradient, pullback, @nograd
 export gradient

src/layers/normalise.jl

Lines changed: 39 additions & 16 deletions
@@ -2,7 +2,9 @@ istraining() = false

 ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)

-_isactive(m) = isnothing(m.active) ? istraining() : m.active
+_isactive(m) = isnothing(m.active) ? istraining() : Bool(m.active)
+
+ChainRulesCore.@non_differentiable _isactive(::Any)

 _dropout_shape(s, ::Colon) = size(s)
 _dropout_shape(s, dims) = tuple((i ∈ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
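
Note: two things happen in this hunk. The Bool(m.active) conversion narrows the inferred return type: `active` is a Union{Bool,Nothing} field on a mutable struct, so the compiler cannot carry the isnothing check over to a second field load, and the old method inferred as Union{Bool,Nothing}. The @non_differentiable annotation additionally tells ChainRules-aware ADs such as Zygote not to trace into the flag lookup at all. A standalone sketch of the inference effect, using a hypothetical Toggle type rather than Flux's actual layer structs:

mutable struct Toggle
    active::Union{Bool,Nothing}  # nothing means "defer to istraining()"
end

istraining() = false

# Each `m.active` is a separate field load from a mutable struct, so the
# isnothing check on the first load tells inference nothing about the second:
_isactive_old(m::Toggle) = isnothing(m.active) ? istraining() : m.active
# Base.return_types(_isactive_old, (Toggle,))  # Union{Bool, Nothing}

# Bool(...) makes both branches provably Bool; it would throw for nothing,
# but that case is already excluded at runtime by the isnothing branch:
_isactive_new(m::Toggle) = isnothing(m.active) ? istraining() : Bool(m.active)
# Base.return_types(_isactive_new, (Toggle,))  # Bool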
@@ -31,26 +33,50 @@ automatically managed using the [`Dropout`](@ref) layer instead of the

 The [`Dropout`](@ref) layer is what you should use in most scenarios.
 """
-function dropout(rng, x, p; dims=:, active::Bool=true)
-  active || return x
-  y = dropout_mask(rng, x, p, dims=dims)
-  return x .* y
-end
+dropout(rng, x, p; dims=:, active::Bool=true) = _dropout(rng, x, p, dims, active)
 dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)

-dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
-dropout_mask(rng, x::CuArray, p; kwargs...) =
+# Internal function without kwargs to keep Zygote generated code type stable
+function _dropout(rng, x, p, dims, active)
+  mask = active ? dropout_mask(rng, x, p, dims=dims) : nothing
+  return _apply_mask(x, mask)
+end
+
+function ChainRulesCore.rrule(::typeof(_dropout), rng, x, p, dims, active)
+  mask = active ? dropout_mask(rng, x, p, dims=dims) : nothing
+  MT = Core.Compiler.return_type(dropout_mask, Tuple{typeof(rng),typeof(x),typeof(p),typeof(dims)})
+  project_x = ProjectTo(x)
+  return _apply_mask(x, mask), DropoutPullback{MT,typeof(project_x)}(mask, project_x)
+end
+
+# Also needed for type stability. Otherwise inference lifts the Union into a
+# Union{Pullback{Nothing}, Pullback{AbstractArray}}
+struct DropoutPullback{M<:AbstractArray,P<:ProjectTo{AbstractArray}}
+  mask::Union{Nothing,M}
+  project::P
+end
+
+function (pb::DropoutPullback)(dy)
+  dx = pb.project(_apply_mask(dy, pb.mask))
+  return (NoTangent(), NoTangent(), dx, NoTangent())
+end
+
+_apply_mask(x, ::Nothing) = x
+_apply_mask(x, mask) = x .* mask
+
+dropout_mask(rng::CUDA.RNG, x::CuArray, p, dims) = _dropout_mask(rng, x, p, dims)
+dropout_mask(rng, x::CuArray, p, dims) =
   throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
-dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
-function _dropout_mask(rng, x, p; dims=:)
+dropout_mask(rng, x, p, dims) = _dropout_mask(rng, x, p, dims)
+function _dropout_mask(rng, x, p, dims)
   realfptype = float(real(eltype(x)))
   y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims)))
   y .= _dropout_kernel.(y, p, 1 - p)
   return y
 end

 # TODO move this to NNlib
-ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
+ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any, ::Any)

 """
     Dropout(p; dims=:, rng = rng_from_array())
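
Note: the lifted-Union comment is easiest to see in miniature. Below is a hedged, self-contained sketch (the names unstable, stable, and MaskPullback are invented for illustration; this is not Flux code) of why returning branch-dependent closures produces a Union of pullback types, while a single concrete struct with the Union buried in a field does not. The rrule above plays the same trick, using Core.Compiler.return_type to recover the concrete mask type M even on the branch where the runtime mask is nothing:

# Branch-dependent closures: the inferred result is a Union of two tuple
# types, because each anonymous function has its own type:
function unstable(x, active::Bool)
    if active
        mask = rand(length(x))
        return x .* mask, dy -> dy .* mask  # closure capturing mask
    else
        return x, dy -> dy                  # a different closure type
    end
end

# One concrete pullback type with Union{Nothing,M} inside a field: the
# outer return type is now a single concrete Tuple (for x::Vector{Float64}):
struct MaskPullback{M}
    mask::Union{Nothing,M}
end
(pb::MaskPullback)(dy) = pb.mask === nothing ? dy : dy .* pb.mask

function stable(x, active::Bool)
    mask = active ? rand(length(x)) : nothing
    y = mask === nothing ? x : x .* mask
    return y, MaskPullback{Vector{Float64}}(mask)
end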
@@ -82,10 +108,7 @@ end
 @functor Dropout
 trainable(a::Dropout) = (;)

-function (a::Dropout)(x)
-  _isactive(a) || return x
-  return dropout(a.rng, x, a.p; dims=a.dims, active=true)
-end
+(a::Dropout)(x) = _dropout(a.rng, x, a.p, a.dims, _isactive(a))

 testmode!(m::Dropout, mode=true) =
   (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
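
With the branch moved into _dropout, the layer's forward pass is a one-liner and both modes share one code path. A quick usage sketch, assuming Flux at this commit:

using Flux

m = Dropout(0.3)
x = rand(Float32, 5, 2)

Flux.trainmode!(m)   # sets m.active = true
y = m(x)             # mask applied: entries zeroed, survivors scaled by 1/(1-p)

Flux.testmode!(m)    # sets m.active = false
@assert m(x) == x    # _apply_mask(x, nothing) returns x unchanged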
@@ -172,7 +195,7 @@ LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]

 @functor LayerNorm

-(a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
+(a::LayerNorm)(x) = a.diag(_normalize(x, 1:length(a.size), a.ϵ))

 function Base.show(io::IO, l::LayerNorm)
   print(io, "LayerNorm(", join(l.size, ", "))

src/layers/stateless.jl

Lines changed: 32 additions & 6 deletions
@@ -26,15 +26,41 @@ function flatten(x::AbstractArray)
   return reshape(x, :, size(x)[end])
 end

+# Utils for LayerNorm internals.
+# Most of these are required for better performance and type stability under AD.
+# In an ideal world, we'd just have normalise.
+
+function _mean_std(x::AbstractArray, dims)
+  μ = mean(x, dims=dims)
+  σ = std(x, dims=dims, mean=μ, corrected=false)
+  return μ, σ
+end
+
+function ChainRulesCore.rrule(::typeof(_mean_std), x::AbstractArray, dims)
+  μ, mean_pullback = ChainRulesCore.rrule(mean, x, dims=dims)
+  σ, std_pullback = ChainRulesCore.rrule(std, x, dims=dims, mean=μ, corrected=false)
+  function _mean_std_pullback((dμ, dσ))
+    dx = ChainRulesCore.add!!(std_pullback(dσ)[2], mean_pullback(dμ)[2])
+    return (NoTangent(), dx, NoTangent())
+  end
+
+  return (μ, σ), _mean_std_pullback
+end
+
+_zscore(x, μ, σ, ϵ) = (x - μ) / (σ + ϵ)
+
+# We don't define a rrule for the whole function because we want
+# AD to figure out the _zscore broadcast for us.
+function _normalize(x::AbstractArray, dims, ϵ)
+  μ, σ = _mean_std(x, dims)
+  return _zscore.(x, μ, σ, ϵ)
+end
+
 """
     normalise(x; dims=ndims(x), ϵ=1e-5)

 Normalise `x` to mean 0 and standard deviation 1 across the dimension(s) given by `dims`.
-Per default, `dims` is the last dimension.
+Per default, `dims` is the last dimension.
 `ϵ` is a small additive factor added to the denominator for numerical stability.
 """
-@inline function normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5))
-  μ = mean(x, dims=dims)
-  σ = std(x, dims=dims, mean=μ, corrected=false)
-  return @. (x - μ) / (σ + ϵ)
-end
+@inline normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5)) = _normalize(x, dims, ϵ)
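
The split buys two things: _mean_std reuses μ between the forward mean and std calls and fuses their pullbacks with add!! (one cotangent accumulation, possibly in place, instead of two separate allocations), while _zscore stays a plain scalar function whose broadcast Zygote can differentiate by itself, so no rrule is needed for _normalize as a whole. A self-contained sketch of the composed behaviour, reproducing the two helpers so it runs without Flux; the fused rrule is omitted here, so Zygote falls back to ChainRules' mean/std rules:

using Statistics, Zygote

_zscore(x, μ, σ, ϵ) = (x - μ) / (σ + ϵ)

function _normalize(x::AbstractArray, dims, ϵ)
    μ = mean(x, dims=dims)
    σ = std(x, dims=dims, mean=μ, corrected=false)
    return _zscore.(x, μ, σ, ϵ)   # AD sees a plain broadcast of a scalar function
end

x = randn(Float32, 3, 5)
y = _normalize(x, 1, 1f-5)
@assert all(abs.(mean(y, dims=1)) .< 1f-3)  # each column has ≈ zero mean

g = Zygote.gradient(x -> sum(abs2, _normalize(x, 1, 1f-5)), x)[1]
@assert size(g) == size(x)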
