diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 8a8865e779..d77d4f5d0d 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -72,7 +72,7 @@ Base.show(io::IO, d::Dirichlet) = show(io, d, (:alpha,)) length(d::Dirichlet) = length(d.alpha) mean(d::Dirichlet) = d.alpha .* inv(d.alpha0) params(d::Dirichlet) = (d.alpha,) -@inline partype(d::Dirichlet{T}) where {T<:Real} = T +@inline partype(::Dirichlet{T}) where {T<:Real} = T function var(d::Dirichlet) α0 = d.alpha0 @@ -375,3 +375,62 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64}, elogp = mean_logp(suffstats(Dirichlet, P, w)) fit_dirichlet!(elogp, α; maxiter=maxiter, tol=tol, debug=debug) end + +## Differentiation +function ChainRulesCore.frule((_, Δalpha)::Tuple{Any,Any}, ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}} + d = DT(alpha; check_args=check_args) + ∂alpha0 = sum(Δalpha) + digamma_alpha0 = SpecialFunctions.digamma(d.alpha0) + ∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha) do Δalphai, alphai + Δalphai * (SpecialFunctions.digamma(alphai) - digamma_alpha0) + end)) + Δd = ChainRulesCore.Tangent{typeof(d)}(; alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB) + return d, Δd +end + +function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}} + d = DT(alpha; check_args=check_args) + digamma_alpha0 = SpecialFunctions.digamma(d.alpha0) + function Dirichlet_pullback(_Δd) + Δd = ChainRulesCore.unthunk(_Δd) + Δalpha = Δd.alpha .+ Δd.alpha0 .+ Δd.lmnB .* (SpecialFunctions.digamma.(alpha) .- digamma_alpha0) + return ChainRulesCore.NoTangent(), Δalpha + end + return d, Dirichlet_pullback +end + +function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real}) + Ω = _logpdf(d, x) + ∂alpha = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalphai, Δxi, alphai, xi + xlogy(Δalphai, xi) + (alphai - 1) * Δxi / xi + end)) + ∂lmnB = -Δd.lmnB + ΔΩ = ∂alpha + ∂lmnB + if !isfinite(Ω) + ΔΩ = oftype(ΔΩ, NaN) + end + return Ω, ΔΩ +end + +function ChainRulesCore.rrule(::typeof(_logpdf), d::T, x::AbstractVector{<:Real}) where {T<:Dirichlet} + Ω = _logpdf(d, x) + isfinite_Ω = isfinite(Ω) + alpha = d.alpha + function _logpdf_Dirichlet_pullback(_ΔΩ) + ΔΩ = ChainRulesCore.unthunk(_ΔΩ) + ∂alpha = _logpdf_Dirichlet_∂alphai.(x, ΔΩ, isfinite_Ω) + ∂lmnB = isfinite_Ω ? -float(ΔΩ) : oftype(float(ΔΩ), NaN) + Δd = ChainRulesCore.Tangent{T}(; alpha=∂alpha, lmnB=∂lmnB) + Δx = _logpdf_Dirichlet_Δxi.(ΔΩ, alpha, x, isfinite_Ω) + return ChainRulesCore.NoTangent(), Δd, Δx + end + return Ω, _logpdf_Dirichlet_pullback +end +function _logpdf_Dirichlet_∂alphai(xi, ΔΩi, isfinite::Bool) + ∂alphai = xlogy.(ΔΩi, xi) + return isfinite ? ∂alphai : oftype(∂alphai, NaN) +end +function _logpdf_Dirichlet_Δxi(ΔΩi, alphai, xi, isfinite::Bool) + Δxi = ΔΩi * (alphai - 1) / xi + return isfinite ? Δxi : oftype(Δxi, NaN) +end diff --git a/test/dirichlet.jl b/test/dirichlet.jl index 1b3a18b521..78de162dca 100644 --- a/test/dirichlet.jl +++ b/test/dirichlet.jl @@ -2,7 +2,9 @@ using Distributions using Test, Random, LinearAlgebra - +using ChainRulesCore +using ChainRulesTestUtils +using FiniteDifferences Random.seed!(34567) @@ -127,3 +129,32 @@ end @test entropy(Dirichlet(N, 1)) ≈ -loggamma(N) @test entropy(Dirichlet(ones(N))) ≈ -loggamma(N) end + +@testset "Dirichlet: ChainRules (length=$n)" for n in (2, 10) + alpha = rand(n) + d = Dirichlet(alpha) + + @testset "constructor $T" for T in (Dirichlet, Dirichlet{Float64}) + # Avoid issues with finite differencing if values in `alpha` become negative or zero + # by using forward differencing + test_frule(T, alpha; fdm=forward_fdm(5, 1)) + test_rrule(T, alpha; fdm=forward_fdm(5, 1)) + end + + @testset "_logpdf" begin + # `x1` is in the support, `x2` isn't + x1 = rand(n) + x1 ./= sum(x1) + x2 = x1 .+ 1 + + # Use special finite differencing method that tries to avoid moving outside of the + # support by limiting the range of the points around the input that are evaluated + fdm = central_fdm(5, 1; max_range=1e-9) + + for x in (x1, x2) + # We have to adjust the tolerance since the finite differencing method is rough + test_frule(Distributions._logpdf, d, x; fdm=fdm, rtol=1e-5, nans=true) + test_rrule(Distributions._logpdf, d, x; fdm=fdm, rtol=1e-5, nans=true) + end + end +end