@@ -382,9 +382,9 @@ function ChainRulesCore.frule((_, Δalpha), ::Type{DT}, alpha::AbstractVector{T}
382
382
Δalpha = ChainRulesCore. unthunk (Δalpha)
383
383
∂alpha0 = sum (Δalpha)
384
384
digamma_alpha0 = SpecialFunctions. digamma (d. alpha0)
385
- ∂lmnB = sum (Broadcast. instantiate (Broadcast. broadcasted (Δalpha, alpha))) do Δalpha_i, alpha_i
385
+ ∂lmnB = sum (Broadcast. instantiate (Broadcast. broadcasted (Δalpha, alpha) do Δalpha_i, alpha_i
386
386
Δalpha_i * (SpecialFunctions. digamma (alpha_i) - digamma_alpha0)
387
- end
387
+ end ))
388
388
backing = (alpha= Δalpha, alpha0= ∂alpha0, lmnB= ∂lmnB)
389
389
t = ChainRulesCore. Tangent {typeof(d), NamedTuple{(:alpha, :alpha0, :lmnB), Tuple{typeof(alpha), typeof(d.alpha0), typeof(d.lmnB)}}} (backing)
390
390
return d, t
@@ -406,9 +406,9 @@ function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x:
406
406
if ! insupport (d, x)
407
407
return (lp, zero (lp) + zero (eltype (Δx)) + zero (eltype (Δd. alpha)) + zero (eltype (Δd. lmnB)))
408
408
end
409
- ∂α_x = sum (Broadcast. broadcasted (Δd. alpha, Δx, d. alpha, x) ) do Δalpha_i, Δx_i, alpha_i, x_i
409
+ ∂α_x = sum (Broadcast. instantiate (Broadcast . broadcasted (Δd. alpha, Δx, d. alpha, x) do Δalpha_i, Δx_i, alpha_i, x_i
410
410
xlogy (Δalpha_i, x_i) + (alpha_i - 1 ) * Δx_i / x_i
411
- end
411
+ end ))
412
412
∂l = - Δd. lmnB
413
413
return (lp, ∂α_x + ∂l)
414
414
end
0 commit comments