Skip to content

Commit bcb7ba9

Browse files
authored
Improvements for LayerNorm (#1911)
1 parent 3c935cc commit bcb7ba9

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

src/layers/normalise.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,22 +164,21 @@ struct LayerNorm{F,D,T,N}
164164
affine::Bool
165165
end
166166

167-
function LayerNorm(sz, λ=identity; affine=true, ϵ=1f-5)
168-
sz = sz isa Integer ? (sz,) : sz
169-
diag = affine ? Diagonal(sz...) : nothing
170-
return LayerNorm(λ, diag, ϵ, sz, affine)
167+
function LayerNorm(sz, λ=identity; affine::Bool=true, ϵ::Real=1f-5)
168+
diag = affine ? Diagonal(sz...) : identity
169+
return LayerNorm(λ, diag, ϵ, Tuple(sz), affine)
171170
end
172171

173172
@functor LayerNorm
174173

175174
function (a::LayerNorm)(x)
176-
x = normalise(x, dims=1:length(a.size), ϵ=a.ϵ)
177-
a.diag === nothing ? a.λ.(x) : a.λ.(a.diag(x))
175+
x = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
176+
return a.λ === identity ? x : a.λ.(x)
178177
end
179178

180179
function Base.show(io::IO, l::LayerNorm)
181180
print(io, "LayerNorm($(l.size)")
182-
l.λ == identity || print(io, ", $(l.λ)")
181+
l.λ === identity || print(io, ", ", l.λ)
183182
hasaffine(l) || print(io, ", affine=false")
184183
print(io, ")")
185184
end

src/layers/stateless.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,8 @@ Normalise `x` to mean 0 and standard deviation 1 across the dimension(s) given b
3333
Per default, `dims` is the last dimension.
3434
`ϵ` is a small additive factor added to the denominator for numerical stability.
3535
"""
36-
function normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5))
36+
@inline function normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5))
3737
μ = mean(x, dims=dims)
38-
# σ = std(x, dims=dims, mean=μ, corrected=false) # use this when Zygote#478 gets merged
39-
σ = std(x, dims=dims, corrected=false)
40-
return (x .- μ) ./.+ ϵ)
38+
σ = std(x, dims=dims, mean=μ, corrected=false)
39+
return @. (x - μ) /+ ϵ)
4140
end

0 commit comments

Comments
 (0)