Skip to content

Commit 9234155

Browse files
authored
Simplify implementation and tests in #1534 (#1555)
* Simplify implementation and tests * Precompute `digamma(alpha0)` * Relax type signature
1 parent 76cc96a commit 9234155

File tree

2 files changed

+64
-83
lines changed

2 files changed

+64
-83
lines changed

src/multivariate/dirichlet.jl

Lines changed: 40 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -377,58 +377,60 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64},
377377
end
378378

379379
## Differentiation
380-
function ChainRulesCore.frule((_, Δalpha), ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
380+
function ChainRulesCore.frule((_, Δalpha)::Tuple{Any,Any}, ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
381381
d = DT(alpha; check_args=check_args)
382-
Δalpha = ChainRulesCore.unthunk(Δalpha)
383382
∂alpha0 = sum(Δalpha)
384383
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
385-
∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha) do Δalpha_i, alpha_i
386-
Δalpha_i * (SpecialFunctions.digamma(alpha_i) - digamma_alpha0)
384+
∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha) do Δalphai, alphai
385+
Δalphai * (SpecialFunctions.digamma(alphai) - digamma_alpha0)
387386
end))
388-
backing = (alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB)
389-
t = ChainRulesCore.Tangent{typeof(d), NamedTuple{(:alpha, :alpha0, :lmnB), Tuple{typeof(alpha), typeof(d.alpha0), typeof(d.lmnB)}}}(backing)
390-
return d, t
387+
Δd = ChainRulesCore.Tangent{typeof(d)}(; alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB)
388+
return d, Δd
391389
end
392390

393391
function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
394392
d = DT(alpha; check_args=check_args)
395-
function dirichlet_pullback(d_dir)
396-
d_dir = ChainRulesCore.unthunk(d_dir)
397-
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
398-
dalpha = d_dir.alpha .+ d_dir.alpha0 .+ d_dir.lmnB .* (SpecialFunctions.digamma.(alpha) .- digamma_alpha0)
399-
return ChainRulesCore.NoTangent(), dalpha
393+
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
394+
function Dirichlet_pullback(_Δd)
395+
Δd = ChainRulesCore.unthunk(_Δd)
396+
Δalpha = Δd.alpha .+ Δd.alpha0 .+ Δd.lmnB .* (SpecialFunctions.digamma.(alpha) .- digamma_alpha0)
397+
return ChainRulesCore.NoTangent(), Δalpha
400398
end
401-
return d, dirichlet_pullback
399+
return d, Dirichlet_pullback
402400
end
403401

404-
function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real})
405-
lp = _logpdf(d, x)
406-
∂α_x = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalpha_i, Δx_i, alpha_i, x_i
407-
xlogy(Δalpha_i, x_i) + (alpha_i - 1) * Δx_i / x_i
402+
function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real})
403+
Ω = _logpdf(d, x)
404+
∂alpha = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalphai, Δxi, alphai, xi
405+
xlogy(Δalphai, xi) + (alphai - 1) * Δxi / xi
408406
end))
409-
∂l = -Δd.lmnB
410-
if !insupport(d, x)
411-
∂α_x = oftype(∂α_x, NaN)
407+
∂lmnB = -Δd.lmnB
408+
ΔΩ = ∂alpha + ∂lmnB
409+
if !isfinite(Ω)
410+
ΔΩ = oftype(ΔΩ, NaN)
412411
end
413-
return (lp, ∂α_x + ∂l)
412+
return Ω, ΔΩ
414413
end
415414

416-
function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real})
417-
y = _logpdf(d, x)
418-
function Dirichlet_logpdf_pullback(dy)
419-
∂alpha = xlogy.(dy, x)
420-
∂l = -dy
421-
∂x = dy .* (d.alpha .-1) ./ x
422-
∂alpha0 = sum(∂alpha)
423-
if !isfinite(y)
424-
∂alpha = oftype(eltype(∂alpha), NaN) * ∂alpha
425-
∂l = oftype(∂l, NaN)
426-
∂x = oftype(eltype(∂x), NaN) * ∂x
427-
∂alpha0 = oftype(eltype(∂alpha), NaN)
428-
end
429-
backing = (alpha = ∂alpha, alpha0 = ∂alpha0, lmnB=∂l)
430-
∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing)
431-
return (ChainRulesCore.NoTangent(), ∂d, ∂x)
415+
function ChainRulesCore.rrule(::typeof(_logpdf), d::T, x::AbstractVector{<:Real}) where {T<:Dirichlet}
416+
Ω = _logpdf(d, x)
417+
isfinite_Ω = isfinite(Ω)
418+
alpha = d.alpha
419+
function _logpdf_Dirichlet_pullback(_ΔΩ)
420+
ΔΩ = ChainRulesCore.unthunk(_ΔΩ)
421+
∂alpha = _logpdf_Dirichlet_∂alphai.(x, ΔΩ, isfinite_Ω)
422+
∂lmnB = isfinite_Ω ? -float(ΔΩ) : oftype(float(ΔΩ), NaN)
423+
Δd = ChainRulesCore.Tangent{T}(; alpha=∂alpha, lmnB=∂lmnB)
424+
Δx = _logpdf_Dirichlet_Δxi.(ΔΩ, alpha, x, isfinite_Ω)
425+
return ChainRulesCore.NoTangent(), Δd, Δx
432426
end
433-
return (y, Dirichlet_logpdf_pullback)
427+
return Ω, _logpdf_Dirichlet_pullback
428+
end
429+
function _logpdf_Dirichlet_∂alphai(xi, ΔΩi, isfinite::Bool)
430+
∂alphai = xlogy.(ΔΩi, xi)
431+
return isfinite ? ∂alphai : oftype(∂alphai, NaN)
432+
end
433+
function _logpdf_Dirichlet_Δxi(ΔΩi, alphai, xi, isfinite::Bool)
434+
Δxi = ΔΩi * (alphai - 1) / xi
435+
return isfinite ? Δxi : oftype(Δxi, NaN)
434436
end

test/dirichlet.jl

Lines changed: 24 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -130,52 +130,31 @@ end
130130
@test entropy(Dirichlet(ones(N))) ≈ -loggamma(N)
131131
end
132132

133-
@testset "Dirichlet differentiation $n" for n in (2, 10)
133+
@testset "Dirichlet: ChainRules (length=$n)" for n in (2, 10)
134134
alpha = rand(n)
135-
Δalpha = randn(n)
136-
d, ∂d = @inferred ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha)
137-
ChainRulesTestUtils.test_frule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha; fdm=FiniteDifferences.forward_fdm(5, 1))
138-
ChainRulesTestUtils.test_rrule(Dirichlet{Float64}, alpha; fdm=FiniteDifferences.forward_fdm(5, 1))
139-
x = rand(n)
140-
x ./= sum(x)
141-
Δx = 0.05 * rand(n)
142-
Δx .-= mean(Δx)
143-
# such that x ∈ Δ, x + Δx ∈ Δ
144-
ChainRulesTestUtils.test_frule(Distributions._logpdf, d, x ⊢ Δx, fdm=FiniteDifferences.forward_fdm(5, 1))
145-
@testset "finite diff f/r-rule logpdf" begin
146-
for _ in 1:10
147-
x = rand(n)
148-
x ./= sum(x)
149-
Δx = 0.005 * rand(n)
150-
Δx .-= mean(Δx)
151-
if insupport(d, x + Δx) && insupport(d, x - Δx)
152-
y, pullback = ChainRulesCore.rrule(Distributions._logpdf, d, x)
153-
yf, Δy = ChainRulesCore.frule(
154-
(
155-
ChainRulesCore.NoTangent(),
156-
map(zero, ChainRulesTestUtils.rand_tangent(d)),
157-
Δx,
158-
),
159-
Distributions._logpdf,
160-
d, x,
161-
)
162-
y2 = Distributions._logpdf(d, x + Δx)
163-
y1 = Distributions._logpdf(d, x - Δx)
164-
@test isfinite(y)
165-
@test y == yf
166-
@test Δy ≈ y2 - y atol=5e-3
167-
_, ∂d, ∂x = pullback(1.0)
168-
@test y2 - y1 ≈ dot(2Δx, ∂x) atol=5e-3 rtol=1e-6
169-
# mutating alpha only to compute a new y, changing only this term and not the others in Dirichlet
170-
Δalpha = 0.03 * rand(n)
171-
Δalpha .-= mean(Δalpha)
172-
@assert all(>=(0), alpha + Δalpha)
173-
d.alpha .+= Δalpha
174-
ya = Distributions._logpdf(d, x)
175-
# resetting alpha
176-
d.alpha .-= Δalpha
177-
@test ya - y ≈ dot(Δalpha, ∂d.alpha) atol=1e-6 rtol=1e-6
178-
end
135+
d = Dirichlet(alpha)
136+
137+
@testset "constructor $T" for T in (Dirichlet, Dirichlet{Float64})
138+
# Avoid issues with finite differencing if values in `alpha` become negative or zero
139+
# by using forward differencing
140+
test_frule(T, alpha; fdm=forward_fdm(5, 1))
141+
test_rrule(T, alpha; fdm=forward_fdm(5, 1))
142+
end
143+
144+
@testset "_logpdf" begin
145+
# `x1` is in the support, `x2` isn't
146+
x1 = rand(n)
147+
x1 ./= sum(x1)
148+
x2 = x1 .+ 1
149+
150+
# Use special finite differencing method that tries to avoid moving outside of the
151+
# support by limiting the range of the points around the input that are evaluated
152+
fdm = central_fdm(5, 1; max_range=1e-9)
153+
154+
for x in (x1, x2)
155+
# We have to adjust the tolerance since the finite differencing method is rough
156+
test_frule(Distributions._logpdf, d, x; fdm=fdm, rtol=1e-5, nans=true)
157+
test_rrule(Distributions._logpdf, d, x; fdm=fdm, rtol=1e-5, nans=true)
179158
end
180159
end
181160
end

0 commit comments

Comments
 (0)