Skip to content

Commit d2f832b

Browse files
committed
fix tests
1 parent 0500772 commit d2f832b

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

src/multivariate/dirichlet.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x:
404404
return (lp, zero(lp) + zero(eltype(Δx)) + zero(eltype(Δd.alpha)) + zero(eltype(Δd.lmnB)))
405405
end
406406
∂α_x = sum(eachindex(x)) do i
407-
xlogy(Δd.alpha[i], log(x[i])) + (d.alpha[i] - 1) * Δx[i] / x[i]
407+
xlogy(Δd.alpha[i], x[i]) + (d.alpha[i] - 1) * Δx[i] / x[i]
408408
end
409409
∂l = - Δd.lmnB
410410
return (lp, ∂α_x + ∂l)
@@ -419,7 +419,7 @@ function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector
419419
∂x = zero(d.alpha + x)
420420
return (ChainRulesCore.NoTangent(), ∂d, ∂x)
421421
end
422-
∂alpha = xlogxy.(dy, x)
422+
∂alpha = xlogy.(dy, x)
423423
∂l = -dy
424424
∂x = dy * (d.alpha .-1) ./ x
425425
backing = (alpha = ∂alpha, alpha0 = ChainRulesCore.ZeroTangent(), lmnB=∂l)

test/dirichlet.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ end
135135
Δalpha = randn(n)
136136
d, ∂d = @inferred ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha)
137137
ChainRulesTestUtils.test_frule(Dirichlet ChainRulesCore.NoTangent(), alpha Δalpha; fdm=FiniteDifferences.forward_fdm(5, 1))
138-
_, dp = ChainRulesCore.rrule(Dirichlet, alpha)
139138
ChainRulesTestUtils.test_rrule(Dirichlet{Float64}, alpha)
140139
x = rand(n)
141140
x ./= sum(x)

0 commit comments

Comments
 (0)