Skip to content

Commit 1a3fdd9

Browse files
committed
do not assume inplace
1 parent e702017 commit 1a3fdd9

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

src/multivariate/dirichlet.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -419,12 +419,12 @@ function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector
419419
∂alpha = xlogy.(dy, x)
420420
∂l = -dy
421421
∂x = dy * (d.alpha .-1) ./ x
422-
∂alpha0 = 0.0
422+
∂alpha0 = sum(∂alpha)
423423
if !isfinite(y)
424-
∂alpha .= NaN
424+
∂alpha = oftype(eltype(∂alpha), NaN) * ∂alpha
425425
∂l = oftype(∂l, NaN)
426-
∂x .= NaN
427-
∂alpha0 = NaN
426+
∂x = oftype(eltype(∂x), NaN) * ∂x
427+
∂alpha0 = oftype(eltype(∂alpha), NaN)
428428
end
429429
backing = (alpha = ∂alpha, alpha0 = ∂alpha0, lmnB=∂l)
430430
∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing)

test/dirichlet.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ end
141141
Δx = 0.05 * rand(n)
142142
Δx .-= mean(Δx)
143143
# such that x ∈ Δ, x + Δx ∈ Δ
144-
ChainRulesTestUtils.test_frule(Distributions._logpdf, d, x Δx)
144+
ChainRulesTestUtils.test_frule(Distributions._logpdf, d, x Δx, fdm=FiniteDifferences.forward_fdm(5, 1))
145145
@testset "finite diff f/r-rule logpdf" begin
146146
for _ in 1:10
147147
x = rand(n)
@@ -174,7 +174,7 @@ end
174174
ya = Distributions._logpdf(d, x)
175175
# resetting alpha
176176
d.alpha .-= Δalpha
177-
@test ya - y dot(Δalpha, ∂d.alpha) atol=5e-5 rtol=1e-6
177+
@test ya - y dot(Δalpha, ∂d.alpha) atol=1e-6 rtol=1e-6
178178
end
179179
end
180180
end

0 commit comments

Comments
 (0)