@@ -382,7 +382,9 @@ function ChainRulesCore.frule((_, Δalpha), ::Type{DT}, alpha::AbstractVector{T}
382
382
Δalpha = ChainRulesCore. unthunk (Δalpha)
383
383
∂alpha0 = sum (Δalpha)
384
384
digamma_alpha0 = SpecialFunctions. digamma (d. alpha0)
385
- ∂lmnB = sum (Δalpha[i] * (SpecialFunctions. digamma (alpha[i]) - digamma_alpha0) for i in eachindex (alpha))
385
+ ∂lmnB = sum (Broadcast. instantiate (Broadcast. broadcasted (Δalpha, alpha))) do Δalpha_i, alpha_i
386
+ Δalpha_i * (SpecialFunctions. digamma (alpha_i) - digamma_alpha0)
387
+ end
386
388
backing = (alpha= Δalpha, alpha0= ∂alpha0, lmnB= ∂lmnB)
387
389
t = ChainRulesCore. Tangent {typeof(d), NamedTuple{(:alpha, :alpha0, :lmnB), Tuple{typeof(alpha), typeof(d.alpha0), typeof(d.lmnB)}}} (backing)
388
390
return d, t
@@ -392,7 +394,8 @@ function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::
392
394
d = DT (alpha; check_args= check_args)
393
395
function dirichlet_pullback (d_dir)
394
396
d_dir = ChainRulesCore. unthunk (d_dir)
395
- dalpha = d_dir. alpha .+ d_dir. alpha0 .+ d_dir. lmnB .* (SpecialFunctions. digamma .(alpha) .- SpecialFunctions. digamma .(d. alpha0))
397
+ digamma_alpha0 = SpecialFunctions. digamma (d. alpha0)
398
+ dalpha = d_dir. alpha .+ d_dir. alpha0 .+ d_dir. lmnB .* (SpecialFunctions. digamma .(alpha) .- digamma_alpha0)
396
399
return ChainRulesCore. NoTangent (), dalpha
397
400
end
398
401
return d, dirichlet_pullback
@@ -403,8 +406,8 @@ function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x:
403
406
if ! insupport (d, x)
404
407
return (lp, zero (lp) + zero (eltype (Δx)) + zero (eltype (Δd. alpha)) + zero (eltype (Δd. lmnB)))
405
408
end
406
- ∂α_x = sum (eachindex ( x)) do i
407
- xlogy (Δd . alpha[i], x[i] ) + (d . alpha[i] - 1 ) * Δx[i] / x[i]
409
+ ∂α_x = sum (Broadcast . broadcasted (Δd . alpha, Δx, d . alpha, x)) do Δalpha_i, Δx_i, alpha_i, x_i
410
+ xlogy (Δalpha_i, x_i ) + (alpha_i - 1 ) * Δx_i / x_i
408
411
end
409
412
∂l = - Δd. lmnB
410
413
return (lp, ∂α_x + ∂l)
0 commit comments