Skip to content

Commit 91f2d47

Browse files
Epsilon change in normalise for stability (#2421)
* epsilon change for stability * Change comment for eps Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com> --------- Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com> Co-authored-by: Carlo Lucibello <carlo.lucibello@unibocconi.it>
1 parent e1989b5 commit 91f2d47

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

src/layers/stateless.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11

22
"""
3-
normalise(x; dims=ndims(x), eps=1e-5)
3+
normalise(x; dims=ndims(x), eps=1f-5)
44
55
Normalise `x` to mean 0 and standard deviation 1 across the dimension(s) given by `dims`.
66
Per default, `dims` is the last dimension.
7-
`eps` is a small term added to the denominator for numerical stability.
7+
`eps` is a small term added to the variance for numerical stability.
88
99
# Examples
1010
```jldoctest
@@ -34,10 +34,11 @@ julia> isapprox(std(y; dims=1, corrected=false), ones(1, 10), atol=1e-5)
3434
true
3535
```
3636
"""
37-
@inline function normalise(x::AbstractArray; dims=ndims(x), eps=ofeltype(x, 1e-5))
37+
@inline function normalise(x::AbstractArray; dims=ndims(x), eps=1f-5)
3838
μ = mean(x, dims=dims)
39-
σ = std(x, dims=dims, mean=μ, corrected=false)
40-
return @. (x - μ) /+ eps)
39+
σ² = var(x, dims=dims, mean=μ, corrected=false)
40+
ε = ofeltype(x, eps)
41+
return @. (x - μ) / sqrt(σ² + ε^2)
4142
end
4243

4344
"""

0 commit comments

Comments
 (0)