Skip to content

Commit d407bf2

Browse files
committed
frule tested
1 parent ae0a0ad commit d407bf2

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

src/multivariate/dirichlet.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,10 +381,23 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64},
381381
end
382382

383383
## Differentiation
384+
using Test
384385
function ChainRulesCore.frule((_, Δalpha), DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, alpha::AbstractVector{T}; check_args = true) where {T}
385386
d = DT(alpha; check_args=check_args)
386387
Δalpha = ChainRulesCore.unthunk(Δalpha)
387388
∂alpha0 = sum(Δalpha)
388-
∂lmnB = (sum(SpecialFunctions.digamma(αi) for αi in alpha) - SpecialFunctions.digamma(d.alpha0)) * Δalpha
389-
return d, ChainRulesCore.Tangent{typeof(d)}(; alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB)
389+
∂lmnB::typeof(∂alpha0) = sum(Δalpha[i] * (SpecialFunctions.digamma(alpha[i]) - SpecialFunctions.digamma(d.alpha0)) for i in eachindex(alpha))
390+
backing = (alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB)
391+
t = ChainRulesCore.Tangent{typeof(d), NamedTuple{(:alpha, :alpha0, :lmnB), Tuple{typeof(alpha), typeof(d.alpha0), typeof(d.lmnB)}}}(backing)
392+
return d, t
393+
end
394+
395+
function ChainRulesCore.rrule(DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, alpha::AbstractVector{T}; check_args = true) where {T}
396+
d = DT(alpha; check_args=check_args)
397+
function dirichlet_pullback(d_dir)
398+
d_dir = ChainRulesCore.unthunk(alpha)
399+
@info typeof(d_dir)
400+
return (ChainRulesCore.NoTangent(), d_dir.alpha)
401+
end
402+
return d, dirichlet_pullback
390403
end

test/dirichlet.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ end
134134
alpha = rand(n)
135135
Δalpha = randn(n)
136136
d2, ∂d = ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha)
137-
ChainRulesTestUtils.test_frule(Dirichlet{Float64}, alpha Δalpha)
137+
ChainRulesTestUtils.test_frule(Dirichlet ChainRulesCore.NoTangent(), alpha Δalpha, check_inferred=false)
138+
139+
_, dp = ChainRulesCore.rrule(Dirichlet ChainRulesCore.NoTangent(), alpha Δalpha)
140+
ChainRulesTestUtils.test_rrule(Dirichlet{Float64} ChainRulesCore.NoTangent(), alpha)
138141
end
139142
end

0 commit comments

Comments
 (0)