Skip to content

Commit ccf1732

Browse files
authored
fast_maximum also for logsumexp (#456)
In the spirit of #450 ...
1 parent 591ac09 commit ccf1732

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/softmax.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,13 +141,13 @@ Without `dims` keyword this returns a scalar.
141141
See also [`logsoftmax`](@ref).
142142
"""
143143
function logsumexp(x::AbstractArray; dims = :)
144-
max_ = maximum(x; dims)
144+
max_ = fast_maximum(x; dims)
145145
@fastmath max_ .+ log.(sum(exp.(x .- max_); dims))
146146
end
147147

148148
function rrule(::typeof(logsumexp), x; dims = :)
149149
# The gradient is `softmax`, but both compute `tmp` so it's worth saving.
150-
max_ = maximum(x; dims)
150+
max_ = fast_maximum(x; dims)
151151
@fastmath tmp = exp.(x .- max_)
152152
@fastmath y = max_ .+ log.(sum(tmp; dims))
153153
logsumexp_pullback(dy) = (NoTangent(), unthunk(dy) .* tmp ./ sum(tmp; dims))

0 commit comments

Comments
 (0)