Skip to content

Commit 1f06aa6

Browse files
committed
fix broadcast
1 parent feafacd commit 1f06aa6

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/multivariate/dirichlet.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,9 +382,9 @@ function ChainRulesCore.frule((_, Δalpha), ::Type{DT}, alpha::AbstractVector{T}
382382
Δalpha = ChainRulesCore.unthunk(Δalpha)
383383
∂alpha0 = sum(Δalpha)
384384
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
385-
∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha))) do Δalpha_i, alpha_i
385+
∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha) do Δalpha_i, alpha_i
386386
Δalpha_i * (SpecialFunctions.digamma(alpha_i) - digamma_alpha0)
387-
end
387+
end))
388388
backing = (alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB)
389389
t = ChainRulesCore.Tangent{typeof(d), NamedTuple{(:alpha, :alpha0, :lmnB), Tuple{typeof(alpha), typeof(d.alpha0), typeof(d.lmnB)}}}(backing)
390390
return d, t
@@ -406,9 +406,9 @@ function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x:
406406
if !insupport(d, x)
407407
return (lp, zero(lp) + zero(eltype(Δx)) + zero(eltype(Δd.alpha)) + zero(eltype(Δd.lmnB)))
408408
end
409-
∂α_x = sum(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x)) do Δalpha_i, Δx_i, alpha_i, x_i
409+
∂α_x = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalpha_i, Δx_i, alpha_i, x_i
410410
xlogy(Δalpha_i, x_i) + (alpha_i - 1) * Δx_i / x_i
411-
end
411+
end))
412412
∂l = - Δd.lmnB
413413
return (lp, ∂α_x + ∂l)
414414
end

0 commit comments

Comments
 (0)