We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
fast_maximum
logsumexp
1 parent 591ac09 commit ccf1732Copy full SHA for ccf1732
src/softmax.jl
@@ -141,13 +141,13 @@ Without `dims` keyword this returns a scalar.
141
See also [`logsoftmax`](@ref).
142
"""
143
function logsumexp(x::AbstractArray; dims = :)
144
- max_ = maximum(x; dims)
+ max_ = fast_maximum(x; dims)
145
@fastmath max_ .+ log.(sum(exp.(x .- max_); dims))
146
end
147
148
function rrule(::typeof(logsumexp), x; dims = :)
149
# The gradient is `softmax`, but both compute `tmp` so it's worth saving.
150
151
@fastmath tmp = exp.(x .- max_)
152
@fastmath y = max_ .+ log.(sum(tmp; dims))
153
logsumexp_pullback(dy) = (NoTangent(), unthunk(dy) .* tmp ./ sum(tmp; dims))
0 commit comments