Skip to content

Commit e702017

Browse files
committed
switch off-support value to NaN
1 parent 1f06aa6 commit e702017

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

src/multivariate/dirichlet.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -403,29 +403,30 @@ end
403403

404404
function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real})
405405
lp = _logpdf(d, x)
406-
if !insupport(d, x)
407-
return (lp, zero(lp) + zero(eltype(Δx)) + zero(eltype(Δd.alpha)) + zero(eltype(Δd.lmnB)))
408-
end
409406
∂α_x = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalpha_i, Δx_i, alpha_i, x_i
410407
xlogy(Δalpha_i, x_i) + (alpha_i - 1) * Δx_i / x_i
411408
end))
412-
∂l = - Δd.lmnB
409+
∂l = -Δd.lmnB
410+
if !insupport(d, x)
411+
∂α_x = oftype(∂α_x, NaN)
412+
end
413413
return (lp, ∂α_x + ∂l)
414414
end
415415

416416
function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real})
417417
y = _logpdf(d, x)
418418
function Dirichlet_logpdf_pullback(dy)
419-
if !isfinite(y)
420-
backing = (alpha = zero(d.alpha), alpha0 = ChainRulesCore.ZeroTangent(), lmnB=zero(d.lmnB))
421-
∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing)
422-
∂x = zero(d.alpha + x)
423-
return (ChainRulesCore.NoTangent(), ∂d, ∂x)
424-
end
425419
∂alpha = xlogy.(dy, x)
426420
∂l = -dy
427421
∂x = dy * (d.alpha .-1) ./ x
428-
backing = (alpha = ∂alpha, alpha0 = ChainRulesCore.ZeroTangent(), lmnB=∂l)
422+
∂alpha0 = 0.0
423+
if !isfinite(y)
424+
∂alpha .= NaN
425+
∂l = oftype(∂l, NaN)
426+
∂x .= NaN
427+
∂alpha0 = NaN
428+
end
429+
backing = (alpha = ∂alpha, alpha0 = ∂alpha0, lmnB=∂l)
429430
∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing)
430431
return (ChainRulesCore.NoTangent(), ∂d, ∂x)
431432
end

0 commit comments

Comments
 (0)