Skip to content

Commit 0500772

Browse files
committed
single loop
1 parent 4cc7509 commit 0500772

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

src/multivariate/dirichlet.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,8 @@ function ChainRulesCore.frule((_, Δalpha), ::Type{DT}, alpha::AbstractVector{T}
381381
d = DT(alpha; check_args=check_args)
382382
Δalpha = ChainRulesCore.unthunk(Δalpha)
383383
∂alpha0 = sum(Δalpha)
384-
∂lmnB = sum(Δalpha[i] * (SpecialFunctions.digamma(alpha[i]) - SpecialFunctions.digamma(d.alpha0)) for i in eachindex(alpha))
384+
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
385+
∂lmnB = sum(Δalpha[i] * (SpecialFunctions.digamma(alpha[i]) - digamma_alpha0) for i in eachindex(alpha))
385386
backing = (alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB)
386387
t = ChainRulesCore.Tangent{typeof(d), NamedTuple{(:alpha, :alpha0, :lmnB), Tuple{typeof(alpha), typeof(d.alpha0), typeof(d.lmnB)}}}(backing)
387388
return d, t
@@ -402,13 +403,14 @@ function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x:
402403
if !insupport(d, x)
403404
return (lp, zero(lp) + zero(eltype(Δx)) + zero(eltype(Δd.alpha)) + zero(eltype(Δd.lmnB)))
404405
end
405-
∂α = sum(Δd.alpha[i] * log(x[i]) for i in eachindex(x))
406+
∂α_x = sum(eachindex(x)) do i
407+
xlogy(Δd.alpha[i], x[i]) + (d.alpha[i] - 1) * Δx[i] / x[i]
408+
end
406409
∂l = - Δd.lmnB
407-
∂x = sum((d.alpha[i] - 1) * Δx[i] / x[i] for i in eachindex(x))
408-
return (lp, ∂α + ∂l + ∂x)
410+
return (lp, ∂α_x + ∂l)
409411
end
410412

411-
function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector{T}) where {T <: Real}
413+
function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real})
412414
y = _logpdf(d, x)
413415
function Dirichlet_logpdf_pullback(dy)
414416
if !isfinite(y)
@@ -417,7 +419,7 @@ function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector
417419
∂x = zero(d.alpha + x)
418420
return (ChainRulesCore.NoTangent(), ∂d, ∂x)
419421
end
420-
∂alpha = dy * log.(x)
422+
∂alpha = xlogy.(dy, x)
421423
∂l = -dy
422424
∂x = dy * (d.alpha .-1) ./ x
423425
backing = (alpha = ∂alpha, alpha0 = ChainRulesCore.ZeroTangent(), lmnB=∂l)

0 commit comments

Comments
 (0)