Skip to content

Commit 4cc7509

Browse files
committed
eltype instability
1 parent bc29c40 commit 4cc7509

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/multivariate/dirichlet.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,15 +400,15 @@ end
400400
function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real})
401401
lp = _logpdf(d, x)
402402
if !insupport(d, x)
403-
return (lp, zero(lp))
403+
return (lp, zero(lp) + zero(eltype(Δx)) + zero(eltype(Δd.alpha)) + zero(eltype(Δd.lmnB)))
404404
end
405405
∂α = sum(Δd.alpha[i] * log(x[i]) for i in eachindex(x))
406406
∂l = - Δd.lmnB
407407
∂x = sum((d.alpha[i] - 1) * Δx[i] / x[i] for i in eachindex(x))
408408
return (lp, ∂α + ∂l + ∂x)
409409
end
410410

411-
function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector{T}) where {T}
411+
function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector{T}) where {T <: Real}
412412
y = _logpdf(d, x)
413413
function Dirichlet_logpdf_pullback(dy)
414414
if !isfinite(y)

0 commit comments

Comments
 (0)