From ae0a0ad2107734fe489b5768f4de207603bfa6d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sat, 16 Apr 2022 14:41:12 +0200 Subject: [PATCH 01/26] constructor frule --- Project.toml | 2 ++ src/multivariate/dirichlet.jl | 25 +++++++++++++++++++------ test/dirichlet.jl | 12 +++++++++++- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 20f8257815..a59f354163 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,9 @@ authors = ["JuliaStats"] version = "0.25.53" [deps] +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 8a8865e779..2975099fd3 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -26,10 +26,12 @@ struct Dirichlet{T<:Real,Ts<:AbstractVector{T},S<:Real} <: ContinuousMultivariat lmnB::S function Dirichlet{T}(alpha::AbstractVector{T}; check_args::Bool=true) where T - @check_args( - Dirichlet, - (alpha, all(x -> x > zero(x), alpha), "alpha must be a positive vector."), - ) + if check_args + @check_args( + Dirichlet, + (alpha, all(x -> x > zero(x), alpha), "alpha must be a positive vector."), + ) + end alpha0 = sum(alpha) lmnB = sum(loggamma, alpha) - loggamma(alpha0) new{T,typeof(alpha),typeof(lmnB)}(alpha, alpha0, lmnB) @@ -40,7 +42,9 @@ function Dirichlet(alpha::AbstractVector{T}; check_args::Bool=true) where {T<:Re Dirichlet{T}(alpha; check_args=check_args) end function Dirichlet(d::Integer, alpha::Real; check_args::Bool=true) - @check_args Dirichlet (d, d > zero(d)) (alpha, alpha > zero(alpha)) + if check_args + @check_args Dirichlet (d, d > zero(d)) (alpha, alpha > zero(alpha)) + end return Dirichlet{typeof(alpha)}(Fill(alpha, d); check_args=false) end @@ -72,7 +76,7 @@ Base.show(io::IO, d::Dirichlet) = show(io, d, (:alpha,)) length(d::Dirichlet) = length(d.alpha) mean(d::Dirichlet) = d.alpha .* inv(d.alpha0) params(d::Dirichlet) = (d.alpha,) -@inline partype(d::Dirichlet{T}) where {T<:Real} = T +@inline partype(::Dirichlet{T}) where {T<:Real} = T function var(d::Dirichlet) α0 = d.alpha0 @@ -375,3 +379,12 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64}, elogp = mean_logp(suffstats(Dirichlet, P, w)) fit_dirichlet!(elogp, α; maxiter=maxiter, tol=tol, debug=debug) end + +## Differentiation +function ChainRulesCore.frule((_, Δalpha), DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, alpha::AbstractVector{T}; check_args = true) where {T} + d = DT(alpha; check_args=check_args) + Δalpha = ChainRulesCore.unthunk(Δalpha) + ∂alpha0 = sum(Δalpha) + ∂lmnB = (sum(SpecialFunctions.digamma(αi) for αi in alpha) - SpecialFunctions.digamma(d.alpha0)) * Δalpha + return d, ChainRulesCore.Tangent{typeof(d)}(; alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB) +end diff --git a/test/dirichlet.jl b/test/dirichlet.jl index 1b3a18b521..4e6984e4c5 100644 --- a/test/dirichlet.jl +++ b/test/dirichlet.jl @@ -2,7 +2,8 @@ using Distributions using Test, Random, LinearAlgebra - +using ChainRulesCore +using ChainRulesTestUtils Random.seed!(34567) @@ -127,3 +128,12 @@ end @test entropy(Dirichlet(N, 1)) ≈ -loggamma(N) @test entropy(Dirichlet(ones(N))) ≈ -loggamma(N) end + +@testset "Dirichlet differentiation" begin + for n in (2, 10) + alpha = rand(n) + Δalpha = randn(n) + d2, ∂d = ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha) + ChainRulesTestUtils.test_frule(Dirichlet{Float64}, alpha ⊢ Δalpha) + end +end From d407bf23119459a5aed2c6cbf4a335854eb687d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sat, 16 Apr 2022 15:52:07 +0200 Subject: [PATCH 02/26] frule tested --- src/multivariate/dirichlet.jl | 17 +++++++++++++++-- test/dirichlet.jl | 5 ++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 2975099fd3..c20195e3fb 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -381,10 +381,23 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64}, end ## Differentiation +using Test function ChainRulesCore.frule((_, Δalpha), DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, alpha::AbstractVector{T}; check_args = true) where {T} d = DT(alpha; check_args=check_args) Δalpha = ChainRulesCore.unthunk(Δalpha) ∂alpha0 = sum(Δalpha) - ∂lmnB = (sum(SpecialFunctions.digamma(αi) for αi in alpha) - SpecialFunctions.digamma(d.alpha0)) * Δalpha - return d, ChainRulesCore.Tangent{typeof(d)}(; alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB) + ∂lmnB::typeof(∂alpha0) = sum(Δalpha[i] * (SpecialFunctions.digamma(alpha[i]) - SpecialFunctions.digamma(d.alpha0)) for i in eachindex(alpha)) + backing = (alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB) + t = ChainRulesCore.Tangent{typeof(d), NamedTuple{(:alpha, :alpha0, :lmnB), Tuple{typeof(alpha), typeof(d.alpha0), typeof(d.lmnB)}}}(backing) + return d, t +end + +function ChainRulesCore.rrule(DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, alpha::AbstractVector{T}; check_args = true) where {T} + d = DT(alpha; check_args=check_args) + function dirichlet_pullback(d_dir) + d_dir = ChainRulesCore.unthunk(alpha) + @info typeof(d_dir) + return (ChainRulesCore.NoTangent(), d_dir.alpha) + end + return d, dirichlet_pullback end diff --git a/test/dirichlet.jl b/test/dirichlet.jl index 4e6984e4c5..9efb90cc1b 100644 --- a/test/dirichlet.jl +++ b/test/dirichlet.jl @@ -134,6 +134,9 @@ end alpha = rand(n) Δalpha = randn(n) d2, ∂d = ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha) - ChainRulesTestUtils.test_frule(Dirichlet{Float64}, alpha ⊢ Δalpha) + ChainRulesTestUtils.test_frule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha, check_inferred=false) + + _, dp = ChainRulesCore.rrule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha) + ChainRulesTestUtils.test_rrule(Dirichlet{Float64} ⊢ ChainRulesCore.NoTangent(), alpha) end end From d5a293aa64a56d607277b5339a8560f44d7f687c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sat, 16 Apr 2022 18:32:04 +0200 Subject: [PATCH 03/26] rrule tests --- src/multivariate/dirichlet.jl | 6 +++--- test/dirichlet.jl | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index c20195e3fb..949f891450 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -395,9 +395,9 @@ end function ChainRulesCore.rrule(DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, alpha::AbstractVector{T}; check_args = true) where {T} d = DT(alpha; check_args=check_args) function dirichlet_pullback(d_dir) - d_dir = ChainRulesCore.unthunk(alpha) - @info typeof(d_dir) - return (ChainRulesCore.NoTangent(), d_dir.alpha) + d_dir = ChainRulesCore.unthunk(d_dir) + ∂l = d_dir.lmnB * (SpecialFunctions.digamma.(alpha) .- SpecialFunctions.digamma.(d.alpha0)) + return (ChainRulesCore.NoTangent(), d_dir.alpha .+ d_dir.alpha0 .+ ∂l) end return d, dirichlet_pullback end diff --git a/test/dirichlet.jl b/test/dirichlet.jl index 9efb90cc1b..5409219227 100644 --- a/test/dirichlet.jl +++ b/test/dirichlet.jl @@ -134,9 +134,9 @@ end alpha = rand(n) Δalpha = randn(n) d2, ∂d = ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha) - ChainRulesTestUtils.test_frule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha, check_inferred=false) + ChainRulesTestUtils.test_frule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha, check_inferred=true) - _, dp = ChainRulesCore.rrule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha) + _, dp = ChainRulesCore.rrule(Dirichlet, alpha) ChainRulesTestUtils.test_rrule(Dirichlet{Float64} ⊢ ChainRulesCore.NoTangent(), alpha) end end From f9de7b36da780627d556a95f408c5d5333afe107 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sun, 17 Apr 2022 19:24:29 +0200 Subject: [PATCH 04/26] logpdf test --- src/multivariate/dirichlet.jl | 40 +++++++++++++++++++++++- test/dirichlet.jl | 57 +++++++++++++++++++++++++++++------ 2 files changed, 87 insertions(+), 10 deletions(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 949f891450..c05e7ac642 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -381,7 +381,6 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64}, end ## Differentiation -using Test function ChainRulesCore.frule((_, Δalpha), DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, alpha::AbstractVector{T}; check_args = true) where {T} d = DT(alpha; check_args=check_args) Δalpha = ChainRulesCore.unthunk(Δalpha) @@ -401,3 +400,42 @@ function ChainRulesCore.rrule(DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, al end return d, dirichlet_pullback end + +function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{T}) where {T} + lp = _logpdf(d, x) + if !insupport(d, x) + return (lp, zero(lp)) + end + ∂α = sum(Δd.alpha[i] * log(x[i]) for i in eachindex(x)) + ∂l = - Δd.lmnB + ∂x = sum((d.alpha[i] - 1) * Δx[i] / x[i] for i in eachindex(x)) + return (lp, ∂α + ∂l + ∂x) +end + +function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector{T}) where {T} + y = _logpdf(d, x) + function Dirichlet_logpdf_pullback(dy) + if !isfinite(y) + backing = (alpha = zero(d.alpha), alpha0 = ChainRulesCore.ZeroTangent(), lmnB=zero(d.lmnB)) + ∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing) + ∂x = zero(d.alpha + x) + return (ChainRulesCore.NoTangent(), ∂d, ∂x) + end + ∂alpha = dy * log.(x) + ∂l = -dy + ∂x = dy * (d.alpha .-1) ./ x + backing = (alpha = ∂alpha, alpha0 = ChainRulesCore.ZeroTangent(), lmnB=∂l) + ∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing) + return (ChainRulesCore.NoTangent(), ∂d, ∂x) + end + return (y, Dirichlet_logpdf_pullback) +end + +function _logpdf(d::Dirichlet, x::AbstractVector{<:Real}) + if !insupport(d, x) + return xlogy(one(eltype(d.alpha)), zero(eltype(x))) - d.lmnB + end + a = d.alpha + s = sum(xlogy(αi - 1, xi) for (αi, xi) in zip(d.alpha, x)) + return s - d.lmnB +end diff --git a/test/dirichlet.jl b/test/dirichlet.jl index 5409219227..e2cadb48da 100644 --- a/test/dirichlet.jl +++ b/test/dirichlet.jl @@ -129,14 +129,53 @@ end @test entropy(Dirichlet(ones(N))) ≈ -loggamma(N) end -@testset "Dirichlet differentiation" begin - for n in (2, 10) - alpha = rand(n) - Δalpha = randn(n) - d2, ∂d = ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha) - ChainRulesTestUtils.test_frule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha, check_inferred=true) - - _, dp = ChainRulesCore.rrule(Dirichlet, alpha) - ChainRulesTestUtils.test_rrule(Dirichlet{Float64} ⊢ ChainRulesCore.NoTangent(), alpha) +@testset "Dirichlet differentiation $n" for n in (2, 10) + alpha = rand(n) + Δalpha = randn(n) + d, ∂d = ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha) + ChainRulesTestUtils.test_frule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha) + _, dp = ChainRulesCore.rrule(Dirichlet, alpha) + ChainRulesTestUtils.test_rrule(Dirichlet{Float64} ⊢ ChainRulesCore.NoTangent(), alpha) + x = rand(n) + x ./= sum(x) + Δx = 0.05 * rand(n) + Δx .-= mean(Δx) + # such that x ∈ Δ, x + Δx ∈ Δ + ChainRulesTestUtils.test_frule(Distributions._logpdf ⊢ ChainRulesCore.NoTangent(), d, x ⊢ Δx) + @testset "finite diff f/r-rule logpdf" begin + for _ in 1:10 + x = rand(n) + x ./= sum(x) + Δx = 0.005 * rand(n) + Δx .-= mean(Δx) + if insupport(d, x + Δx) && insupport(d, x - Δx) + y, pullback = ChainRulesCore.rrule(Distributions._logpdf, d, x) + yf, Δy = ChainRulesCore.frule( + ( + ChainRulesCore.NoTangent(), + map(zero, ChainRulesTestUtils.rand_tangent(d)), + Δx, + ), + Distributions._logpdf, + d, x, + ) + y2 = Distributions._logpdf(d, x + Δx) + y1 = Distributions._logpdf(d, x - Δx) + @test isfinite(y) + @test y == yf + @test Δy ≈ y2 - y atol=5e-3 + _, ∂d, ∂x = pullback(1.0) + @test y2 - y1 ≈ dot(2Δx, ∂x) atol=5e-3 rtol=1e-6 + # mutating alpha only to compute a new y, changing only this term and not the others in Dirichlet + Δalpha = 0.03 * rand(n) + Δalpha .-= mean(Δalpha) + @assert all(>=(0), alpha + Δalpha) + d.alpha .+= Δalpha + ya = Distributions._logpdf(d, x) + # resetting alpha + d.alpha .-= Δalpha + @test ya - y ≈ dot(Δalpha, ∂d.alpha) atol=5e-5 rtol=1e-6 + end + end end end From 3723789b2cb22913f5c9011a60702476715053d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sun, 17 Apr 2022 20:07:11 +0200 Subject: [PATCH 05/26] signature for conflict --- src/multivariate/dirichlet.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index c05e7ac642..00713415e8 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -381,7 +381,7 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64}, end ## Differentiation -function ChainRulesCore.frule((_, Δalpha), DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, alpha::AbstractVector{T}; check_args = true) where {T} +function ChainRulesCore.frule((_, Δalpha)::Tuple{Any,Any}, DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, alpha::AbstractVector{T}; check_args = true) where {T} d = DT(alpha; check_args=check_args) Δalpha = ChainRulesCore.unthunk(Δalpha) ∂alpha0 = sum(Δalpha) From 5e32f04a49c8bb705c53638b23c6a68f653a2dd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sun, 24 Apr 2022 10:21:54 +0200 Subject: [PATCH 06/26] TestUtils out of Project --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index a59f354163..a1a7ba2466 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.25.53" [deps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" From 0dde72f7496c4fd4e937ec755cc14f6194d5f116 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sun, 24 Apr 2022 10:25:32 +0200 Subject: [PATCH 07/26] ChainRules itself not needed (yet?) --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index a1a7ba2466..20f8257815 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,6 @@ authors = ["JuliaStats"] version = "0.25.53" [deps] -ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" From 13487925d8514b929e1b739b87eec54bb14e9479 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sun, 24 Apr 2022 10:39:46 +0200 Subject: [PATCH 08/26] remove checkarg --- src/multivariate/dirichlet.jl | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 00713415e8..1ab17e8d32 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -26,12 +26,10 @@ struct Dirichlet{T<:Real,Ts<:AbstractVector{T},S<:Real} <: ContinuousMultivariat lmnB::S function Dirichlet{T}(alpha::AbstractVector{T}; check_args::Bool=true) where T - if check_args - @check_args( - Dirichlet, - (alpha, all(x -> x > zero(x), alpha), "alpha must be a positive vector."), - ) - end + @check_args( + Dirichlet, + (alpha, all(x -> x > zero(x), alpha), "alpha must be a positive vector."), + ) alpha0 = sum(alpha) lmnB = sum(loggamma, alpha) - loggamma(alpha0) new{T,typeof(alpha),typeof(lmnB)}(alpha, alpha0, lmnB) @@ -42,9 +40,7 @@ function Dirichlet(alpha::AbstractVector{T}; check_args::Bool=true) where {T<:Re Dirichlet{T}(alpha; check_args=check_args) end function Dirichlet(d::Integer, alpha::Real; check_args::Bool=true) - if check_args - @check_args Dirichlet (d, d > zero(d)) (alpha, alpha > zero(alpha)) - end + @check_args Dirichlet (d, d > zero(d)) (alpha, alpha > zero(alpha)) return Dirichlet{typeof(alpha)}(Fill(alpha, d); check_args=false) end From 90455c8a54d6811b8666a70bddcc27a82d9192d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Mon, 25 Apr 2022 08:55:26 +0200 Subject: [PATCH 09/26] Update src/multivariate/dirichlet.jl Co-authored-by: David Widmann --- src/multivariate/dirichlet.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 1ab17e8d32..989c2b25ba 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -397,7 +397,7 @@ function ChainRulesCore.rrule(DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, al return d, dirichlet_pullback end -function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{T}) where {T} +function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real}) lp = _logpdf(d, x) if !insupport(d, x) return (lp, zero(lp)) From 25a41f34cb71b6cc1ae1f7baa5e6825c4e6ca054 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Mon, 25 Apr 2022 08:56:24 +0200 Subject: [PATCH 10/26] Update test/dirichlet.jl Co-authored-by: David Widmann --- test/dirichlet.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dirichlet.jl b/test/dirichlet.jl index e2cadb48da..87a7ad48fd 100644 --- a/test/dirichlet.jl +++ b/test/dirichlet.jl @@ -135,7 +135,7 @@ end d, ∂d = ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha) ChainRulesTestUtils.test_frule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha) _, dp = ChainRulesCore.rrule(Dirichlet, alpha) - ChainRulesTestUtils.test_rrule(Dirichlet{Float64} ⊢ ChainRulesCore.NoTangent(), alpha) + ChainRulesTestUtils.test_rrule(Dirichlet{Float64}, alpha) x = rand(n) x ./= sum(x) Δx = 0.05 * rand(n) From 89a9346d3d76d9369513ef69030d830ccc9bc16c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Mon, 25 Apr 2022 09:01:24 +0200 Subject: [PATCH 11/26] Update test/dirichlet.jl Co-authored-by: David Widmann --- test/dirichlet.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dirichlet.jl b/test/dirichlet.jl index 87a7ad48fd..68ab42e6ce 100644 --- a/test/dirichlet.jl +++ b/test/dirichlet.jl @@ -133,7 +133,7 @@ end alpha = rand(n) Δalpha = randn(n) d, ∂d = ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha) - ChainRulesTestUtils.test_frule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha) + ChainRulesTestUtils.test_frule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha; fdm=FiniteDifferences.forward_fdm(5, 1)) _, dp = ChainRulesCore.rrule(Dirichlet, alpha) ChainRulesTestUtils.test_rrule(Dirichlet{Float64}, alpha) x = rand(n) From 96883e8f1b7e3346496f6d60b2e1b7a4d7c0ab5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Mon, 25 Apr 2022 09:01:35 +0200 Subject: [PATCH 12/26] Update test/dirichlet.jl Co-authored-by: David Widmann --- test/dirichlet.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dirichlet.jl b/test/dirichlet.jl index 68ab42e6ce..d6cacea921 100644 --- a/test/dirichlet.jl +++ b/test/dirichlet.jl @@ -141,7 +141,7 @@ end Δx = 0.05 * rand(n) Δx .-= mean(Δx) # such that x ∈ Δ, x + Δx ∈ Δ - ChainRulesTestUtils.test_frule(Distributions._logpdf ⊢ ChainRulesCore.NoTangent(), d, x ⊢ Δx) + ChainRulesTestUtils.test_frule(Distributions._logpdf, d, x ⊢ Δx) @testset "finite diff f/r-rule logpdf" begin for _ in 1:10 x = rand(n) From ab011227e4bda9cb8e1e17aab97c1f523eb20151 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Mon, 25 Apr 2022 19:37:37 +0200 Subject: [PATCH 13/26] Update src/multivariate/dirichlet.jl Co-authored-by: David Widmann --- src/multivariate/dirichlet.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 989c2b25ba..054a8c41d9 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -391,8 +391,8 @@ function ChainRulesCore.rrule(DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, al d = DT(alpha; check_args=check_args) function dirichlet_pullback(d_dir) d_dir = ChainRulesCore.unthunk(d_dir) - ∂l = d_dir.lmnB * (SpecialFunctions.digamma.(alpha) .- SpecialFunctions.digamma.(d.alpha0)) - return (ChainRulesCore.NoTangent(), d_dir.alpha .+ d_dir.alpha0 .+ ∂l) + dalpha = d_dir.alpha .+ d_dir.alpha0 .+ d_dir.lmnB .* (SpecialFunctions.digamma.(alpha) .- SpecialFunctions.digamma.(d.alpha0)) + return ChainRulesCore.NoTangent(), dalpha end return d, dirichlet_pullback end From 1d79fec94082465cbacb5de68d0874010d734fbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Mon, 25 Apr 2022 19:37:50 +0200 Subject: [PATCH 14/26] conflict --- src/multivariate/dirichlet.jl | 15 +++------------ test/dirichlet.jl | 3 ++- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 989c2b25ba..452e3e2fe4 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -377,17 +377,17 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64}, end ## Differentiation -function ChainRulesCore.frule((_, Δalpha)::Tuple{Any,Any}, DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, alpha::AbstractVector{T}; check_args = true) where {T} +function ChainRulesCore.frule((_, Δalpha), ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}} d = DT(alpha; check_args=check_args) Δalpha = ChainRulesCore.unthunk(Δalpha) ∂alpha0 = sum(Δalpha) - ∂lmnB::typeof(∂alpha0) = sum(Δalpha[i] * (SpecialFunctions.digamma(alpha[i]) - SpecialFunctions.digamma(d.alpha0)) for i in eachindex(alpha)) + ∂lmnB = sum(Δalpha[i] * (SpecialFunctions.digamma(alpha[i]) - SpecialFunctions.digamma(d.alpha0)) for i in eachindex(alpha)) backing = (alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB) t = ChainRulesCore.Tangent{typeof(d), NamedTuple{(:alpha, :alpha0, :lmnB), Tuple{typeof(alpha), typeof(d.alpha0), typeof(d.lmnB)}}}(backing) return d, t end -function ChainRulesCore.rrule(DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, alpha::AbstractVector{T}; check_args = true) where {T} +function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}} d = DT(alpha; check_args=check_args) function dirichlet_pullback(d_dir) d_dir = ChainRulesCore.unthunk(d_dir) @@ -426,12 +426,3 @@ function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector end return (y, Dirichlet_logpdf_pullback) end - -function _logpdf(d::Dirichlet, x::AbstractVector{<:Real}) - if !insupport(d, x) - return xlogy(one(eltype(d.alpha)), zero(eltype(x))) - d.lmnB - end - a = d.alpha - s = sum(xlogy(αi - 1, xi) for (αi, xi) in zip(d.alpha, x)) - return s - d.lmnB -end diff --git a/test/dirichlet.jl b/test/dirichlet.jl index d6cacea921..6483da0969 100644 --- a/test/dirichlet.jl +++ b/test/dirichlet.jl @@ -4,6 +4,7 @@ using Distributions using Test, Random, LinearAlgebra using ChainRulesCore using ChainRulesTestUtils +using FiniteDifferences Random.seed!(34567) @@ -132,7 +133,7 @@ end @testset "Dirichlet differentiation $n" for n in (2, 10) alpha = rand(n) Δalpha = randn(n) - d, ∂d = ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha) + d, ∂d = @inferred ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha) ChainRulesTestUtils.test_frule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha; fdm=FiniteDifferences.forward_fdm(5, 1)) _, dp = ChainRulesCore.rrule(Dirichlet, alpha) ChainRulesTestUtils.test_rrule(Dirichlet{Float64}, alpha) From 4cc75095d5d7772ef83ac392d6e480f162890059 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Mon, 25 Apr 2022 19:43:04 +0200 Subject: [PATCH 15/26] eltype instability --- src/multivariate/dirichlet.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 771ecd8e74..789df687c4 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -400,7 +400,7 @@ end function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real}) lp = _logpdf(d, x) if !insupport(d, x) - return (lp, zero(lp)) + return (lp, zero(lp) + zero(eltype(Δx)) + zero(eltype(Δd.alpha)) + zero(eltype(Δd.lmnB))) end ∂α = sum(Δd.alpha[i] * log(x[i]) for i in eachindex(x)) ∂l = - Δd.lmnB @@ -408,7 +408,7 @@ function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x: return (lp, ∂α + ∂l + ∂x) end -function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector{T}) where {T} +function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector{T}) where {T <: Real} y = _logpdf(d, x) function Dirichlet_logpdf_pullback(dy) if !isfinite(y) From 0500772bfd00442b7f6f4ea7c9c84665b7af7f7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Mon, 25 Apr 2022 19:58:01 +0200 Subject: [PATCH 16/26] single loop --- src/multivariate/dirichlet.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 789df687c4..5bea25f367 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -381,7 +381,8 @@ function ChainRulesCore.frule((_, Δalpha), ::Type{DT}, alpha::AbstractVector{T} d = DT(alpha; check_args=check_args) Δalpha = ChainRulesCore.unthunk(Δalpha) ∂alpha0 = sum(Δalpha) - ∂lmnB = sum(Δalpha[i] * (SpecialFunctions.digamma(alpha[i]) - SpecialFunctions.digamma(d.alpha0)) for i in eachindex(alpha)) + digamma_alpha0 = SpecialFunctions.digamma(d.alpha0) + ∂lmnB = sum(Δalpha[i] * (SpecialFunctions.digamma(alpha[i]) - digamma_alpha0) for i in eachindex(alpha)) backing = (alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB) t = ChainRulesCore.Tangent{typeof(d), NamedTuple{(:alpha, :alpha0, :lmnB), Tuple{typeof(alpha), typeof(d.alpha0), typeof(d.lmnB)}}}(backing) return d, t @@ -402,13 +403,14 @@ function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x: if !insupport(d, x) return (lp, zero(lp) + zero(eltype(Δx)) + zero(eltype(Δd.alpha)) + zero(eltype(Δd.lmnB))) end - ∂α = sum(Δd.alpha[i] * log(x[i]) for i in eachindex(x)) + ∂α_x = sum(eachindex(x)) do i + xlogy(Δd.alpha[i], log(x[i])) + (d.alpha[i] - 1) * Δx[i] / x[i] + end ∂l = - Δd.lmnB - ∂x = sum((d.alpha[i] - 1) * Δx[i] / x[i] for i in eachindex(x)) - return (lp, ∂α + ∂l + ∂x) + return (lp, ∂α_x + ∂l) end -function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector{T}) where {T <: Real} +function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real}) y = _logpdf(d, x) function Dirichlet_logpdf_pullback(dy) if !isfinite(y) @@ -417,7 +419,7 @@ function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector ∂x = zero(d.alpha + x) return (ChainRulesCore.NoTangent(), ∂d, ∂x) end - ∂alpha = dy * log.(x) + ∂alpha = xlogxy.(dy, x) ∂l = -dy ∂x = dy * (d.alpha .-1) ./ x backing = (alpha = ∂alpha, alpha0 = ChainRulesCore.ZeroTangent(), lmnB=∂l) From d2f832ba660aa6ad4cc99e9a88b58c021f060137 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Mon, 25 Apr 2022 20:03:15 +0200 Subject: [PATCH 17/26] fix tests --- src/multivariate/dirichlet.jl | 4 ++-- test/dirichlet.jl | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 5bea25f367..51bdb2ac0b 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -404,7 +404,7 @@ function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x: return (lp, zero(lp) + zero(eltype(Δx)) + zero(eltype(Δd.alpha)) + zero(eltype(Δd.lmnB))) end ∂α_x = sum(eachindex(x)) do i - xlogy(Δd.alpha[i], log(x[i])) + (d.alpha[i] - 1) * Δx[i] / x[i] + xlogy(Δd.alpha[i], x[i]) + (d.alpha[i] - 1) * Δx[i] / x[i] end ∂l = - Δd.lmnB return (lp, ∂α_x + ∂l) @@ -419,7 +419,7 @@ function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector ∂x = zero(d.alpha + x) return (ChainRulesCore.NoTangent(), ∂d, ∂x) end - ∂alpha = xlogxy.(dy, x) + ∂alpha = xlogy.(dy, x) ∂l = -dy ∂x = dy * (d.alpha .-1) ./ x backing = (alpha = ∂alpha, alpha0 = ChainRulesCore.ZeroTangent(), lmnB=∂l) diff --git a/test/dirichlet.jl b/test/dirichlet.jl index 6483da0969..ec11fa7dcf 100644 --- a/test/dirichlet.jl +++ b/test/dirichlet.jl @@ -135,7 +135,6 @@ end Δalpha = randn(n) d, ∂d = @inferred ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha) ChainRulesTestUtils.test_frule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha; fdm=FiniteDifferences.forward_fdm(5, 1)) - _, dp = ChainRulesCore.rrule(Dirichlet, alpha) ChainRulesTestUtils.test_rrule(Dirichlet{Float64}, alpha) x = rand(n) x ./= sum(x) From 77ccee6e1e106381ed3297535dc42d2c7f15de39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Mon, 25 Apr 2022 22:08:13 +0200 Subject: [PATCH 18/26] forward finite diff --- test/dirichlet.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dirichlet.jl b/test/dirichlet.jl index ec11fa7dcf..04923b9414 100644 --- a/test/dirichlet.jl +++ b/test/dirichlet.jl @@ -135,7 +135,7 @@ end Δalpha = randn(n) d, ∂d = @inferred ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha) ChainRulesTestUtils.test_frule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha; fdm=FiniteDifferences.forward_fdm(5, 1)) - ChainRulesTestUtils.test_rrule(Dirichlet{Float64}, alpha) + ChainRulesTestUtils.test_rrule(Dirichlet{Float64}, alpha; fdm=FiniteDifferences.forward_fdm(5, 1)) x = rand(n) x ./= sum(x) Δx = 0.05 * rand(n) From feafacd1a9ef1b5b151ea7206d6dc1681c0d2e1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Mon, 25 Apr 2022 22:32:26 +0200 Subject: [PATCH 19/26] switch to broadcast --- src/multivariate/dirichlet.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 51bdb2ac0b..df30840807 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -382,7 +382,9 @@ function ChainRulesCore.frule((_, Δalpha), ::Type{DT}, alpha::AbstractVector{T} Δalpha = ChainRulesCore.unthunk(Δalpha) ∂alpha0 = sum(Δalpha) digamma_alpha0 = SpecialFunctions.digamma(d.alpha0) - ∂lmnB = sum(Δalpha[i] * (SpecialFunctions.digamma(alpha[i]) - digamma_alpha0) for i in eachindex(alpha)) + ∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha))) do Δalpha_i, alpha_i + Δalpha_i * (SpecialFunctions.digamma(alpha_i) - digamma_alpha0) + end backing = (alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB) t = ChainRulesCore.Tangent{typeof(d), NamedTuple{(:alpha, :alpha0, :lmnB), Tuple{typeof(alpha), typeof(d.alpha0), typeof(d.lmnB)}}}(backing) return d, t @@ -392,7 +394,8 @@ function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args:: d = DT(alpha; check_args=check_args) function dirichlet_pullback(d_dir) d_dir = ChainRulesCore.unthunk(d_dir) - dalpha = d_dir.alpha .+ d_dir.alpha0 .+ d_dir.lmnB .* (SpecialFunctions.digamma.(alpha) .- SpecialFunctions.digamma.(d.alpha0)) + digamma_alpha0 = SpecialFunctions.digamma(d.alpha0) + dalpha = d_dir.alpha .+ d_dir.alpha0 .+ d_dir.lmnB .* (SpecialFunctions.digamma.(alpha) .- digamma_alpha0) return ChainRulesCore.NoTangent(), dalpha end return d, dirichlet_pullback @@ -403,8 +406,8 @@ function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x: if !insupport(d, x) return (lp, zero(lp) + zero(eltype(Δx)) + zero(eltype(Δd.alpha)) + zero(eltype(Δd.lmnB))) end - ∂α_x = sum(eachindex(x)) do i - xlogy(Δd.alpha[i], x[i]) + (d.alpha[i] - 1) * Δx[i] / x[i] + ∂α_x = sum(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x)) do Δalpha_i, Δx_i, alpha_i, x_i + xlogy(Δalpha_i, x_i) + (alpha_i - 1) * Δx_i / x_i end ∂l = - Δd.lmnB return (lp, ∂α_x + ∂l) From 1f06aa6a0277fa626bbfd8180f6f74fb00840b21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Mon, 25 Apr 2022 22:55:12 +0200 Subject: [PATCH 20/26] fix broadcast --- src/multivariate/dirichlet.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index df30840807..143b13d505 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -382,9 +382,9 @@ function ChainRulesCore.frule((_, Δalpha), ::Type{DT}, alpha::AbstractVector{T} Δalpha = ChainRulesCore.unthunk(Δalpha) ∂alpha0 = sum(Δalpha) digamma_alpha0 = SpecialFunctions.digamma(d.alpha0) - ∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha))) do Δalpha_i, alpha_i + ∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha) do Δalpha_i, alpha_i Δalpha_i * (SpecialFunctions.digamma(alpha_i) - digamma_alpha0) - end + end)) backing = (alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB) t = ChainRulesCore.Tangent{typeof(d), NamedTuple{(:alpha, :alpha0, :lmnB), Tuple{typeof(alpha), typeof(d.alpha0), typeof(d.lmnB)}}}(backing) return d, t @@ -406,9 +406,9 @@ function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x: if !insupport(d, x) return (lp, zero(lp) + zero(eltype(Δx)) + zero(eltype(Δd.alpha)) + zero(eltype(Δd.lmnB))) end - ∂α_x = sum(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x)) do Δalpha_i, Δx_i, alpha_i, x_i + ∂α_x = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalpha_i, Δx_i, alpha_i, x_i xlogy(Δalpha_i, x_i) + (alpha_i - 1) * Δx_i / x_i - end + end)) ∂l = - Δd.lmnB return (lp, ∂α_x + ∂l) end From e70201717ca7a931e00479bdce33da457b0bb87c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Mon, 25 Apr 2022 23:11:37 +0200 Subject: [PATCH 21/26] switch off-support value to NaN --- src/multivariate/dirichlet.jl | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 143b13d505..4839746a53 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -403,29 +403,30 @@ end function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real}) lp = _logpdf(d, x) - if !insupport(d, x) - return (lp, zero(lp) + zero(eltype(Δx)) + zero(eltype(Δd.alpha)) + zero(eltype(Δd.lmnB))) - end ∂α_x = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalpha_i, Δx_i, alpha_i, x_i xlogy(Δalpha_i, x_i) + (alpha_i - 1) * Δx_i / x_i end)) - ∂l = - Δd.lmnB + ∂l = -Δd.lmnB + if !insupport(d, x) + ∂α_x = oftype(∂α_x, NaN) + end return (lp, ∂α_x + ∂l) end function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real}) y = _logpdf(d, x) function Dirichlet_logpdf_pullback(dy) - if !isfinite(y) - backing = (alpha = zero(d.alpha), alpha0 = ChainRulesCore.ZeroTangent(), lmnB=zero(d.lmnB)) - ∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing) - ∂x = zero(d.alpha + x) - return (ChainRulesCore.NoTangent(), ∂d, ∂x) - end ∂alpha = xlogy.(dy, x) ∂l = -dy ∂x = dy * (d.alpha .-1) ./ x - backing = (alpha = ∂alpha, alpha0 = ChainRulesCore.ZeroTangent(), lmnB=∂l) + ∂alpha0 = 0.0 + if !isfinite(y) + ∂alpha .= NaN + ∂l = oftype(∂l, NaN) + ∂x .= NaN + ∂alpha0 = NaN + end + backing = (alpha = ∂alpha, alpha0 = ∂alpha0, lmnB=∂l) ∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing) return (ChainRulesCore.NoTangent(), ∂d, ∂x) end From 475a9341800ae907e21650543a442d3579618d0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Fri, 29 Apr 2022 09:37:58 +0200 Subject: [PATCH 22/26] Update src/multivariate/dirichlet.jl Co-authored-by: David Widmann --- src/multivariate/dirichlet.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 4839746a53..1084b719e9 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -418,7 +418,7 @@ function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector function Dirichlet_logpdf_pullback(dy) ∂alpha = xlogy.(dy, x) ∂l = -dy - ∂x = dy * (d.alpha .-1) ./ x + ∂x = dy .* (d.alpha .-1) ./ x ∂alpha0 = 0.0 if !isfinite(y) ∂alpha .= NaN From 7515e86cab0125504b2ae1074e9159bf6d202fe9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Fri, 29 Apr 2022 09:38:44 +0200 Subject: [PATCH 23/26] Update src/multivariate/dirichlet.jl Co-authored-by: David Widmann --- src/multivariate/dirichlet.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 1084b719e9..e9d80a5823 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -424,7 +424,7 @@ function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector ∂alpha .= NaN ∂l = oftype(∂l, NaN) ∂x .= NaN - ∂alpha0 = NaN + ∂alpha0 = oftype(∂alpha0, NaN) end backing = (alpha = ∂alpha, alpha0 = ∂alpha0, lmnB=∂l) ∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing) From 1a3fdd983e6b4a104aecfaace44492f6ea8567c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sun, 1 May 2022 09:47:29 +0200 Subject: [PATCH 24/26] do not assume inplace --- src/multivariate/dirichlet.jl | 8 ++++---- test/dirichlet.jl | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 4839746a53..0a32832102 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -419,12 +419,12 @@ function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector ∂alpha = xlogy.(dy, x) ∂l = -dy ∂x = dy * (d.alpha .-1) ./ x - ∂alpha0 = 0.0 + ∂alpha0 = sum(∂alpha) if !isfinite(y) - ∂alpha .= NaN + ∂alpha = oftype(eltype(∂alpha), NaN) * ∂alpha ∂l = oftype(∂l, NaN) - ∂x .= NaN - ∂alpha0 = NaN + ∂x = oftype(eltype(∂x), NaN) * ∂x + ∂alpha0 = oftype(eltype(∂alpha), NaN) end backing = (alpha = ∂alpha, alpha0 = ∂alpha0, lmnB=∂l) ∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing) diff --git a/test/dirichlet.jl b/test/dirichlet.jl index 04923b9414..98825a7ba1 100644 --- a/test/dirichlet.jl +++ b/test/dirichlet.jl @@ -141,7 +141,7 @@ end Δx = 0.05 * rand(n) Δx .-= mean(Δx) # such that x ∈ Δ, x + Δx ∈ Δ - ChainRulesTestUtils.test_frule(Distributions._logpdf, d, x ⊢ Δx) + ChainRulesTestUtils.test_frule(Distributions._logpdf, d, x ⊢ Δx, fdm=FiniteDifferences.forward_fdm(5, 1)) @testset "finite diff f/r-rule logpdf" begin for _ in 1:10 x = rand(n) @@ -174,7 +174,7 @@ end ya = Distributions._logpdf(d, x) # resetting alpha d.alpha .-= Δalpha - @test ya - y ≈ dot(Δalpha, ∂d.alpha) atol=5e-5 rtol=1e-6 + @test ya - y ≈ dot(Δalpha, ∂d.alpha) atol=1e-6 rtol=1e-6 end end end From cb4f07ebec1c3d971298930aab166a401473b6d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sun, 22 May 2022 21:19:00 -0400 Subject: [PATCH 25/26] fixed temp --- src/multivariate/dirichlet.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 77ff5dd0b0..8485ed63c2 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -417,7 +417,7 @@ function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector y = _logpdf(d, x) function Dirichlet_logpdf_pullback(dy) ∂alpha = xlogy.(dy, x) - ∂l = -dy + ∂l = -dy # + (SpecialFunctions.digamma(alpha_i) - digamma_alpha0) ∂x = dy .* (d.alpha .-1) ./ x ∂alpha0 = sum(∂alpha) if !isfinite(y) From 92341557edb0bba21e298f6450def7136243200a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 24 May 2022 15:28:46 +0200 Subject: [PATCH 26/26] Simplify implementation and tests in #1534 (#1555) * Simplify implementation and tests * Precompute `digamma(alpha0)` * Relax type signature --- src/multivariate/dirichlet.jl | 78 ++++++++++++++++++----------------- test/dirichlet.jl | 69 +++++++++++-------------------- 2 files changed, 64 insertions(+), 83 deletions(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 77ff5dd0b0..d77d4f5d0d 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -377,58 +377,60 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64}, end ## Differentiation -function ChainRulesCore.frule((_, Δalpha), ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}} +function ChainRulesCore.frule((_, Δalpha)::Tuple{Any,Any}, ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}} d = DT(alpha; check_args=check_args) - Δalpha = ChainRulesCore.unthunk(Δalpha) ∂alpha0 = sum(Δalpha) digamma_alpha0 = SpecialFunctions.digamma(d.alpha0) - ∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha) do Δalpha_i, alpha_i - Δalpha_i * (SpecialFunctions.digamma(alpha_i) - digamma_alpha0) + ∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha) do Δalphai, alphai + Δalphai * (SpecialFunctions.digamma(alphai) - digamma_alpha0) end)) - backing = (alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB) - t = ChainRulesCore.Tangent{typeof(d), NamedTuple{(:alpha, :alpha0, :lmnB), Tuple{typeof(alpha), typeof(d.alpha0), typeof(d.lmnB)}}}(backing) - return d, t + Δd = ChainRulesCore.Tangent{typeof(d)}(; alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB) + return d, Δd end function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}} d = DT(alpha; check_args=check_args) - function dirichlet_pullback(d_dir) - d_dir = ChainRulesCore.unthunk(d_dir) - digamma_alpha0 = SpecialFunctions.digamma(d.alpha0) - dalpha = d_dir.alpha .+ d_dir.alpha0 .+ d_dir.lmnB .* (SpecialFunctions.digamma.(alpha) .- digamma_alpha0) - return ChainRulesCore.NoTangent(), dalpha + digamma_alpha0 = SpecialFunctions.digamma(d.alpha0) + function Dirichlet_pullback(_Δd) + Δd = ChainRulesCore.unthunk(_Δd) + Δalpha = Δd.alpha .+ Δd.alpha0 .+ Δd.lmnB .* (SpecialFunctions.digamma.(alpha) .- digamma_alpha0) + return ChainRulesCore.NoTangent(), Δalpha end - return d, dirichlet_pullback + return d, Dirichlet_pullback end -function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real}) - lp = _logpdf(d, x) - ∂α_x = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalpha_i, Δx_i, alpha_i, x_i - xlogy(Δalpha_i, x_i) + (alpha_i - 1) * Δx_i / x_i +function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real}) + Ω = _logpdf(d, x) + ∂alpha = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalphai, Δxi, alphai, xi + xlogy(Δalphai, xi) + (alphai - 1) * Δxi / xi end)) - ∂l = -Δd.lmnB - if !insupport(d, x) - ∂α_x = oftype(∂α_x, NaN) + ∂lmnB = -Δd.lmnB + ΔΩ = ∂alpha + ∂lmnB + if !isfinite(Ω) + ΔΩ = oftype(ΔΩ, NaN) end - return (lp, ∂α_x + ∂l) + return Ω, ΔΩ end -function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real}) - y = _logpdf(d, x) - function Dirichlet_logpdf_pullback(dy) - ∂alpha = xlogy.(dy, x) - ∂l = -dy - ∂x = dy .* (d.alpha .-1) ./ x - ∂alpha0 = sum(∂alpha) - if !isfinite(y) - ∂alpha = oftype(eltype(∂alpha), NaN) * ∂alpha - ∂l = oftype(∂l, NaN) - ∂x = oftype(eltype(∂x), NaN) * ∂x - ∂alpha0 = oftype(eltype(∂alpha), NaN) - end - backing = (alpha = ∂alpha, alpha0 = ∂alpha0, lmnB=∂l) - ∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing) - return (ChainRulesCore.NoTangent(), ∂d, ∂x) +function ChainRulesCore.rrule(::typeof(_logpdf), d::T, x::AbstractVector{<:Real}) where {T<:Dirichlet} + Ω = _logpdf(d, x) + isfinite_Ω = isfinite(Ω) + alpha = d.alpha + function _logpdf_Dirichlet_pullback(_ΔΩ) + ΔΩ = ChainRulesCore.unthunk(_ΔΩ) + ∂alpha = _logpdf_Dirichlet_∂alphai.(x, ΔΩ, isfinite_Ω) + ∂lmnB = isfinite_Ω ? -float(ΔΩ) : oftype(float(ΔΩ), NaN) + Δd = ChainRulesCore.Tangent{T}(; alpha=∂alpha, lmnB=∂lmnB) + Δx = _logpdf_Dirichlet_Δxi.(ΔΩ, alpha, x, isfinite_Ω) + return ChainRulesCore.NoTangent(), Δd, Δx end - return (y, Dirichlet_logpdf_pullback) + return Ω, _logpdf_Dirichlet_pullback +end +function _logpdf_Dirichlet_∂alphai(xi, ΔΩi, isfinite::Bool) + ∂alphai = xlogy.(ΔΩi, xi) + return isfinite ? ∂alphai : oftype(∂alphai, NaN) +end +function _logpdf_Dirichlet_Δxi(ΔΩi, alphai, xi, isfinite::Bool) + Δxi = ΔΩi * (alphai - 1) / xi + return isfinite ? Δxi : oftype(Δxi, NaN) end diff --git a/test/dirichlet.jl b/test/dirichlet.jl index 98825a7ba1..78de162dca 100644 --- a/test/dirichlet.jl +++ b/test/dirichlet.jl @@ -130,52 +130,31 @@ end @test entropy(Dirichlet(ones(N))) ≈ -loggamma(N) end -@testset "Dirichlet differentiation $n" for n in (2, 10) +@testset "Dirichlet: ChainRules (length=$n)" for n in (2, 10) alpha = rand(n) - Δalpha = randn(n) - d, ∂d = @inferred ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha) - ChainRulesTestUtils.test_frule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha; fdm=FiniteDifferences.forward_fdm(5, 1)) - ChainRulesTestUtils.test_rrule(Dirichlet{Float64}, alpha; fdm=FiniteDifferences.forward_fdm(5, 1)) - x = rand(n) - x ./= sum(x) - Δx = 0.05 * rand(n) - Δx .-= mean(Δx) - # such that x ∈ Δ, x + Δx ∈ Δ - ChainRulesTestUtils.test_frule(Distributions._logpdf, d, x ⊢ Δx, fdm=FiniteDifferences.forward_fdm(5, 1)) - @testset "finite diff f/r-rule logpdf" begin - for _ in 1:10 - x = rand(n) - x ./= sum(x) - Δx = 0.005 * rand(n) - Δx .-= mean(Δx) - if insupport(d, x + Δx) && insupport(d, x - Δx) - y, pullback = ChainRulesCore.rrule(Distributions._logpdf, d, x) - yf, Δy = ChainRulesCore.frule( - ( - ChainRulesCore.NoTangent(), - map(zero, ChainRulesTestUtils.rand_tangent(d)), - Δx, - ), - Distributions._logpdf, - d, x, - ) - y2 = Distributions._logpdf(d, x + Δx) - y1 = Distributions._logpdf(d, x - Δx) - @test isfinite(y) - @test y == yf - @test Δy ≈ y2 - y atol=5e-3 - _, ∂d, ∂x = pullback(1.0) - @test y2 - y1 ≈ dot(2Δx, ∂x) atol=5e-3 rtol=1e-6 - # mutating alpha only to compute a new y, changing only this term and not the others in Dirichlet - Δalpha = 0.03 * rand(n) - Δalpha .-= mean(Δalpha) - @assert all(>=(0), alpha + Δalpha) - d.alpha .+= Δalpha - ya = Distributions._logpdf(d, x) - # resetting alpha - d.alpha .-= Δalpha - @test ya - y ≈ dot(Δalpha, ∂d.alpha) atol=1e-6 rtol=1e-6 - end + d = Dirichlet(alpha) + + @testset "constructor $T" for T in (Dirichlet, Dirichlet{Float64}) + # Avoid issues with finite differencing if values in `alpha` become negative or zero + # by using forward differencing + test_frule(T, alpha; fdm=forward_fdm(5, 1)) + test_rrule(T, alpha; fdm=forward_fdm(5, 1)) + end + + @testset "_logpdf" begin + # `x1` is in the support, `x2` isn't + x1 = rand(n) + x1 ./= sum(x1) + x2 = x1 .+ 1 + + # Use special finite differencing method that tries to avoid moving outside of the + # support by limiting the range of the points around the input that are evaluated + fdm = central_fdm(5, 1; max_range=1e-9) + + for x in (x1, x2) + # We have to adjust the tolerance since the finite differencing method is rough + test_frule(Distributions._logpdf, d, x; fdm=fdm, rtol=1e-5, nans=true) + test_rrule(Distributions._logpdf, d, x; fdm=fdm, rtol=1e-5, nans=true) end end end