Skip to content

Commit 07f67c8

Browse files
authored
Avoid maximum in softmax (#450)
* avoid maximum in softmax * fixup * filter doctest * do the same for logsoftmax * revert a mistake * does anyone ever remember what syntax documenter does and doesn't like for more than 5 minutes? * maybe a different thing was wrong with documenter and I was misled by what it complained about
1 parent 57268d1 commit 07f67c8

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

src/softmax.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ See also [`logsoftmax`](@ref).
1717
1818
# Examples
1919
20-
```jldoctest
20+
```jldoctest; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
2121
julia> softmax([1, 2, 3])
2222
3-element Vector{Float64}:
2323
0.09003057317038046
@@ -58,13 +58,14 @@ softmax(x::AbstractArray{T}; dims = 1) where {T} = softmax!(similar(x, float(T))
5858
# In-place form: `x` is passed as both the destination and the input.
function softmax!(x::AbstractArray; dims = 1)
    return softmax!(x, x; dims = dims)
end
5959

6060
# In-place softmax: fills `out` from `x`, normalising along `dims`, and evaluates to `out`.
function softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
61-
max_ = maximum(x; dims)
61+
max_ = fast_maximum(x; dims)
6262
# Fast path: every maximum is finite, so subtract it before exponentiating.
if all(isfinite, max_)
6363
@fastmath out .= exp.(x .- max_)
6464
else
6565
# Where the maximum is Inf, entries equal to Inf become 1 and all others 0;
# everywhere else this still computes exp(x - max_).
@fastmath @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 1, 0), exp(x - max_))
6666
end
67-
out ./= sum(out; dims)
67+
# Reuse the `max_` buffer to hold the sums, unless `dims` is a Colon (plain `sum`).
tmp = dims isa Colon ? sum(out) : sum!(max_, out)
68+
out ./= tmp
6869
end
6970

7071
function ∇softmax_data(dy::AbstractArray{T}, y::AbstractArray{S}; dims = 1) where {T,S}
@@ -75,7 +76,7 @@ function ∇softmax_data(dy::AbstractArray{T}, y::AbstractArray{S}; dims = 1) wh
7576
# This path is faster, only safe for 1st derivatives though.
7677
# Was previously `∇softmax!(dx, dy, x, y; dims)` to allow CUDA overloads,
7778
# but that was slow: https://github.com/FluxML/NNlibCUDA.jl/issues/30
78-
out = similar(y, promote_type(T,S))
79+
out = similar(y, promote_type(T,S)) # sure to be mutable
7980
out .= dy .* y
8081
out .= out .- y .* sum(out; dims)
8182
end
@@ -90,6 +91,7 @@ end
9091
# An ordinary call evaluates to `false`.
function within_grad()
    return false
end
9192
# Rule for `within_grad`: the primal value is `true`, and the pullback ignores
# its argument and returns a one-element tuple holding `NoTangent()`.
function rrule(::typeof(within_grad))
    pullback(_) = (NoTangent(),)
    return true, pullback
end
9293

94+
# Maximum of `x` over `dims`, computed as a `max` reduction under `@fastmath`
# and seeded with `-Inf` converted to the floating-point version of the element type.
function fast_maximum(x::AbstractArray{T}; dims) where {T}
    return @fastmath reduce(max, x; dims = dims, init = float(T)(-Inf))
end
9395

9496
"""
9597
logsoftmax(x; dims = 1)
@@ -109,7 +111,7 @@ logsoftmax(x::AbstractArray{T}; dims = 1) where {T} = logsoftmax!(similar(x, flo
109111
# In-place form: `x` is passed as both the destination and the input.
function logsoftmax!(x::AbstractArray; dims = 1)
    return logsoftmax!(x, x; dims = dims)
end
110112

111113
function logsoftmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
112-
max_ = maximum(x; dims)
114+
max_ = fast_maximum(x; dims)
113115
if all(isfinite, max_)
114116
out .= x .- max_
115117
else

0 commit comments

Comments
 (0)