Skip to content

Commit f9de7b3

Browse files
committed
logpdf test
1 parent d5a293a commit f9de7b3

File tree

2 files changed

+87
-10
lines changed

2 files changed

+87
-10
lines changed

src/multivariate/dirichlet.jl

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,6 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64},
381381
end
382382

383383
## Differentiation
384-
using Test
385384
function ChainRulesCore.frule((_, Δalpha), DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, alpha::AbstractVector{T}; check_args = true) where {T}
386385
d = DT(alpha; check_args=check_args)
387386
Δalpha = ChainRulesCore.unthunk(Δalpha)
@@ -401,3 +400,42 @@ function ChainRulesCore.rrule(DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, al
401400
end
402401
return d, dirichlet_pullback
403402
end
403+
404+
function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{T}) where {T}
405+
lp = _logpdf(d, x)
406+
if !insupport(d, x)
407+
return (lp, zero(lp))
408+
end
409+
∂α = sum(Δd.alpha[i] * log(x[i]) for i in eachindex(x))
410+
∂l = - Δd.lmnB
411+
∂x = sum((d.alpha[i] - 1) * Δx[i] / x[i] for i in eachindex(x))
412+
return (lp, ∂α + ∂l + ∂x)
413+
end
414+
415+
function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector{T}) where {T}
416+
y = _logpdf(d, x)
417+
function Dirichlet_logpdf_pullback(dy)
418+
if !isfinite(y)
419+
backing = (alpha = zero(d.alpha), alpha0 = ChainRulesCore.ZeroTangent(), lmnB=zero(d.lmnB))
420+
∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing)
421+
∂x = zero(d.alpha + x)
422+
return (ChainRulesCore.NoTangent(), ∂d, ∂x)
423+
end
424+
∂alpha = dy * log.(x)
425+
∂l = -dy
426+
∂x = dy * (d.alpha .-1) ./ x
427+
backing = (alpha = ∂alpha, alpha0 = ChainRulesCore.ZeroTangent(), lmnB=∂l)
428+
∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing)
429+
return (ChainRulesCore.NoTangent(), ∂d, ∂x)
430+
end
431+
return (y, Dirichlet_logpdf_pullback)
432+
end
433+
434+
function _logpdf(d::Dirichlet, x::AbstractVector{<:Real})
435+
if !insupport(d, x)
436+
return xlogy(one(eltype(d.alpha)), zero(eltype(x))) - d.lmnB
437+
end
438+
a = d.alpha
439+
s = sum(xlogy(αi - 1, xi) for (αi, xi) in zip(d.alpha, x))
440+
return s - d.lmnB
441+
end

test/dirichlet.jl

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,53 @@ end
129129
@test entropy(Dirichlet(ones(N))) -loggamma(N)
130130
end
131131

132-
@testset "Dirichlet differentiation" begin
133-
for n in (2, 10)
134-
alpha = rand(n)
135-
Δalpha = randn(n)
136-
d2, ∂d = ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha)
137-
ChainRulesTestUtils.test_frule(Dirichlet ChainRulesCore.NoTangent(), alpha Δalpha, check_inferred=true)
138-
139-
_, dp = ChainRulesCore.rrule(Dirichlet, alpha)
140-
ChainRulesTestUtils.test_rrule(Dirichlet{Float64} ChainRulesCore.NoTangent(), alpha)
132+
@testset "Dirichlet differentiation $n" for n in (2, 10)
133+
alpha = rand(n)
134+
Δalpha = randn(n)
135+
d, ∂d = ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha)
136+
ChainRulesTestUtils.test_frule(Dirichlet ChainRulesCore.NoTangent(), alpha Δalpha)
137+
_, dp = ChainRulesCore.rrule(Dirichlet, alpha)
138+
ChainRulesTestUtils.test_rrule(Dirichlet{Float64} ChainRulesCore.NoTangent(), alpha)
139+
x = rand(n)
140+
x ./= sum(x)
141+
Δx = 0.05 * rand(n)
142+
Δx .-= mean(Δx)
143+
# such that x ∈ Δ, x + Δx ∈ Δ
144+
ChainRulesTestUtils.test_frule(Distributions._logpdf ChainRulesCore.NoTangent(), d, x Δx)
145+
@testset "finite diff f/r-rule logpdf" begin
146+
for _ in 1:10
147+
x = rand(n)
148+
x ./= sum(x)
149+
Δx = 0.005 * rand(n)
150+
Δx .-= mean(Δx)
151+
if insupport(d, x + Δx) && insupport(d, x - Δx)
152+
y, pullback = ChainRulesCore.rrule(Distributions._logpdf, d, x)
153+
yf, Δy = ChainRulesCore.frule(
154+
(
155+
ChainRulesCore.NoTangent(),
156+
map(zero, ChainRulesTestUtils.rand_tangent(d)),
157+
Δx,
158+
),
159+
Distributions._logpdf,
160+
d, x,
161+
)
162+
y2 = Distributions._logpdf(d, x + Δx)
163+
y1 = Distributions._logpdf(d, x - Δx)
164+
@test isfinite(y)
165+
@test y == yf
166+
@test Δy y2 - y atol=5e-3
167+
_, ∂d, ∂x = pullback(1.0)
168+
@test y2 - y1 dot(2Δx, ∂x) atol=5e-3 rtol=1e-6
169+
# mutating alpha only to compute a new y, changing only this term and not the others in Dirichlet
170+
Δalpha = 0.03 * rand(n)
171+
Δalpha .-= mean(Δalpha)
172+
@assert all(>=(0), alpha + Δalpha)
173+
d.alpha .+= Δalpha
174+
ya = Distributions._logpdf(d, x)
175+
# resetting alpha
176+
d.alpha .-= Δalpha
177+
@test ya - y dot(Δalpha, ∂d.alpha) atol=5e-5 rtol=1e-6
178+
end
179+
end
141180
end
142181
end

0 commit comments

Comments
 (0)