Skip to content

Commit 1d79fec

Browse files
committed
conflict
1 parent 96883e8 commit 1d79fec

File tree

2 files changed

+5
-13
lines changed

2 files changed

+5
-13
lines changed

src/multivariate/dirichlet.jl

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -377,17 +377,17 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64},
377377
end
378378

379379
## Differentiation
380-
function ChainRulesCore.frule((_, Δalpha)::Tuple{Any,Any}, DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, alpha::AbstractVector{T}; check_args = true) where {T}
380+
function ChainRulesCore.frule((_, Δalpha), ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
381381
d = DT(alpha; check_args=check_args)
382382
Δalpha = ChainRulesCore.unthunk(Δalpha)
383383
∂alpha0 = sum(Δalpha)
384-
∂lmnB::typeof(∂alpha0) = sum(Δalpha[i] * (SpecialFunctions.digamma(alpha[i]) - SpecialFunctions.digamma(d.alpha0)) for i in eachindex(alpha))
384+
∂lmnB = sum(Δalpha[i] * (SpecialFunctions.digamma(alpha[i]) - SpecialFunctions.digamma(d.alpha0)) for i in eachindex(alpha))
385385
backing = (alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB)
386386
t = ChainRulesCore.Tangent{typeof(d), NamedTuple{(:alpha, :alpha0, :lmnB), Tuple{typeof(alpha), typeof(d.alpha0), typeof(d.lmnB)}}}(backing)
387387
return d, t
388388
end
389389

390-
function ChainRulesCore.rrule(DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, alpha::AbstractVector{T}; check_args = true) where {T}
390+
function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
391391
d = DT(alpha; check_args=check_args)
392392
function dirichlet_pullback(d_dir)
393393
d_dir = ChainRulesCore.unthunk(d_dir)
@@ -426,12 +426,3 @@ function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector
426426
end
427427
return (y, Dirichlet_logpdf_pullback)
428428
end
429-
430-
function _logpdf(d::Dirichlet, x::AbstractVector{<:Real})
431-
if !insupport(d, x)
432-
return xlogy(one(eltype(d.alpha)), zero(eltype(x))) - d.lmnB
433-
end
434-
a = d.alpha
435-
s = sum(xlogy(αi - 1, xi) for (αi, xi) in zip(d.alpha, x))
436-
return s - d.lmnB
437-
end

test/dirichlet.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Distributions
44
using Test, Random, LinearAlgebra
55
using ChainRulesCore
66
using ChainRulesTestUtils
7+
using FiniteDifferences
78

89
Random.seed!(34567)
910

@@ -132,7 +133,7 @@ end
132133
@testset "Dirichlet differentiation $n" for n in (2, 10)
133134
alpha = rand(n)
134135
Δalpha = randn(n)
135-
d, ∂d = ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha)
136+
d, ∂d = @inferred ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha)
136137
ChainRulesTestUtils.test_frule(Dirichlet ChainRulesCore.NoTangent(), alpha Δalpha; fdm=FiniteDifferences.forward_fdm(5, 1))
137138
_, dp = ChainRulesCore.rrule(Dirichlet, alpha)
138139
ChainRulesTestUtils.test_rrule(Dirichlet{Float64}, alpha)

0 commit comments

Comments
 (0)