@@ -403,29 +403,30 @@ end
403
403
404
404
function ChainRulesCore. frule ((_, Δd, Δx), :: typeof (_logpdf), d:: Dirichlet , x:: AbstractVector{<:Real} )
405
405
lp = _logpdf (d, x)
406
- if ! insupport (d, x)
407
- return (lp, zero (lp) + zero (eltype (Δx)) + zero (eltype (Δd. alpha)) + zero (eltype (Δd. lmnB)))
408
- end
409
406
∂α_x = sum (Broadcast. instantiate (Broadcast. broadcasted (Δd. alpha, Δx, d. alpha, x) do Δalpha_i, Δx_i, alpha_i, x_i
410
407
xlogy (Δalpha_i, x_i) + (alpha_i - 1 ) * Δx_i / x_i
411
408
end ))
412
- ∂l = - Δd. lmnB
409
+ ∂l = - Δd. lmnB
410
+ if ! insupport (d, x)
411
+ ∂α_x = oftype (∂α_x, NaN )
412
+ end
413
413
return (lp, ∂α_x + ∂l)
414
414
end
415
415
416
416
function ChainRulesCore. rrule (:: typeof (_logpdf), d:: Dirichlet , x:: AbstractVector{<:Real} )
417
417
y = _logpdf (d, x)
418
418
function Dirichlet_logpdf_pullback (dy)
419
- if ! isfinite (y)
420
- backing = (alpha = zero (d. alpha), alpha0 = ChainRulesCore. ZeroTangent (), lmnB= zero (d. lmnB))
421
- ∂d = ChainRulesCore. Tangent {typeof(d), typeof(backing)} (backing)
422
- ∂x = zero (d. alpha + x)
423
- return (ChainRulesCore. NoTangent (), ∂d, ∂x)
424
- end
425
419
∂alpha = xlogy .(dy, x)
426
420
∂l = - dy
427
421
∂x = dy * (d. alpha .- 1 ) ./ x
428
- backing = (alpha = ∂alpha, alpha0 = ChainRulesCore. ZeroTangent (), lmnB= ∂l)
422
+ ∂alpha0 = 0.0
423
+ if ! isfinite (y)
424
+ ∂alpha .= NaN
425
+ ∂l = oftype (∂l, NaN )
426
+ ∂x .= NaN
427
+ ∂alpha0 = NaN
428
+ end
429
+ backing = (alpha = ∂alpha, alpha0 = ∂alpha0, lmnB= ∂l)
429
430
∂d = ChainRulesCore. Tangent {typeof(d), typeof(backing)} (backing)
430
431
return (ChainRulesCore. NoTangent (), ∂d, ∂x)
431
432
end
0 commit comments