Skip to content

Commit feafacd

Browse files
committed
switch to broadcast
1 parent 77ccee6 commit feafacd

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/multivariate/dirichlet.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,9 @@ function ChainRulesCore.frule((_, Δalpha), ::Type{DT}, alpha::AbstractVector{T}
382382
Δalpha = ChainRulesCore.unthunk(Δalpha)
383383
∂alpha0 = sum(Δalpha)
384384
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
385-
∂lmnB = sum(Δalpha[i] * (SpecialFunctions.digamma(alpha[i]) - digamma_alpha0) for i in eachindex(alpha))
385+
∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha))) do Δalpha_i, alpha_i
386+
Δalpha_i * (SpecialFunctions.digamma(alpha_i) - digamma_alpha0)
387+
end
386388
backing = (alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB)
387389
t = ChainRulesCore.Tangent{typeof(d), NamedTuple{(:alpha, :alpha0, :lmnB), Tuple{typeof(alpha), typeof(d.alpha0), typeof(d.lmnB)}}}(backing)
388390
return d, t
@@ -392,7 +394,8 @@ function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::
392394
d = DT(alpha; check_args=check_args)
393395
function dirichlet_pullback(d_dir)
394396
d_dir = ChainRulesCore.unthunk(d_dir)
395-
dalpha = d_dir.alpha .+ d_dir.alpha0 .+ d_dir.lmnB .* (SpecialFunctions.digamma.(alpha) .- SpecialFunctions.digamma.(d.alpha0))
397+
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
398+
dalpha = d_dir.alpha .+ d_dir.alpha0 .+ d_dir.lmnB .* (SpecialFunctions.digamma.(alpha) .- digamma_alpha0)
396399
return ChainRulesCore.NoTangent(), dalpha
397400
end
398401
return d, dirichlet_pullback
@@ -403,8 +406,8 @@ function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x:
403406
if !insupport(d, x)
404407
return (lp, zero(lp) + zero(eltype(Δx)) + zero(eltype(Δd.alpha)) + zero(eltype(Δd.lmnB)))
405408
end
406-
∂α_x = sum(eachindex(x)) do i
407-
xlogy(Δd.alpha[i], x[i]) + (d.alpha[i] - 1) * Δx[i] / x[i]
409+
∂α_x = sum(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x)) do Δalpha_i, Δx_i, alpha_i, x_i
410+
xlogy(Δalpha_i, x_i) + (alpha_i - 1) * Δx_i / x_i
408411
end
409412
∂l = - Δd.lmnB
410413
return (lp, ∂α_x + ∂l)

0 commit comments

Comments
 (0)