@@ -381,7 +381,8 @@ function ChainRulesCore.frule((_, Δalpha), ::Type{DT}, alpha::AbstractVector{T}
381
381
d = DT (alpha; check_args= check_args)
382
382
Δalpha = ChainRulesCore. unthunk (Δalpha)
383
383
∂alpha0 = sum (Δalpha)
384
- ∂lmnB = sum (Δalpha[i] * (SpecialFunctions. digamma (alpha[i]) - SpecialFunctions. digamma (d. alpha0)) for i in eachindex (alpha))
384
+ digamma_alpha0 = SpecialFunctions. digamma (d. alpha0)
385
+ ∂lmnB = sum (Δalpha[i] * (SpecialFunctions. digamma (alpha[i]) - digamma_alpha0) for i in eachindex (alpha))
385
386
backing = (alpha= Δalpha, alpha0= ∂alpha0, lmnB= ∂lmnB)
386
387
t = ChainRulesCore. Tangent {typeof(d), NamedTuple{(:alpha, :alpha0, :lmnB), Tuple{typeof(alpha), typeof(d.alpha0), typeof(d.lmnB)}}} (backing)
387
388
return d, t
@@ -402,13 +403,14 @@ function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x:
402
403
if ! insupport (d, x)
403
404
return (lp, zero (lp) + zero (eltype (Δx)) + zero (eltype (Δd. alpha)) + zero (eltype (Δd. lmnB)))
404
405
end
405
- ∂α = sum (Δd. alpha[i] * log (x[i]) for i in eachindex (x))
406
+ ∂α_x = sum (eachindex (x)) do i
407
+ xlogy (Δd. alpha[i], x[i]) + (d. alpha[i] - 1 ) * Δx[i] / x[i]
408
+ end
406
409
∂l = - Δd. lmnB
407
- ∂x = sum ((d. alpha[i] - 1 ) * Δx[i] / x[i] for i in eachindex (x))
408
- return (lp, ∂α + ∂l + ∂x)
410
+ return (lp, ∂α_x + ∂l)
409
411
end
410
412
411
- function ChainRulesCore. rrule (:: typeof (_logpdf), d:: Dirichlet , x:: AbstractVector{T} ) where {T <: Real }
413
+ function ChainRulesCore. rrule (:: typeof (_logpdf), d:: Dirichlet , x:: AbstractVector{<: Real} )
412
414
y = _logpdf (d, x)
413
415
function Dirichlet_logpdf_pullback (dy)
414
416
if ! isfinite (y)
@@ -417,7 +419,7 @@ function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector
417
419
∂x = zero (d. alpha + x)
418
420
return (ChainRulesCore. NoTangent (), ∂d, ∂x)
419
421
end
420
- ∂alpha = dy * log .( x)
422
+ ∂alpha = xlogy .(dy, x)
421
423
∂l = - dy
422
424
∂x = dy * (d. alpha .- 1 ) ./ x
423
425
backing = (alpha = ∂alpha, alpha0 = ChainRulesCore. ZeroTangent (), lmnB= ∂l)
0 commit comments