1
- function ChainRulesCore. frule ((_, Δalpha):: Tuple{Any,Any} , :: Type{DT} , alpha:: AbstractVector{T} ; check_args:: Bool = true ) where {T <: Real , DT <: Union{Dirichlet{T}, Dirichlet} }
1
+ function ChainRulesCore. frule (:: ChainRulesCore.RuleConfig , (_, Δalpha):: Tuple{Any,Any} , :: Type{DT} , alpha:: AbstractVector{T} ; check_args:: Bool = true ) where {T <: Real , DT <: Union{Dirichlet{T}, Dirichlet} }
2
2
d = DT (alpha; check_args= check_args)
3
3
∂alpha0 = sum (Δalpha)
4
4
digamma_alpha0 = SpecialFunctions. digamma (d. alpha0)
@@ -9,7 +9,7 @@ function ChainRulesCore.frule((_, Δalpha)::Tuple{Any,Any}, ::Type{DT}, alpha::A
9
9
return d, Δd
10
10
end
11
11
12
- function ChainRulesCore. rrule (:: Type{DT} , alpha:: AbstractVector{T} ; check_args:: Bool = true ) where {T <: Real , DT <: Union{Dirichlet{T}, Dirichlet} }
12
+ function ChainRulesCore. rrule (:: ChainRulesCore.RuleConfig , :: Type{DT} , alpha:: AbstractVector{T} ; check_args:: Bool = true ) where {T <: Real , DT <: Union{Dirichlet{T}, Dirichlet} }
13
13
d = DT (alpha; check_args= check_args)
14
14
digamma_alpha0 = SpecialFunctions. digamma (d. alpha0)
15
15
function Dirichlet_pullback (_Δd)
@@ -20,7 +20,7 @@ function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::
20
20
return d, Dirichlet_pullback
21
21
end
22
22
23
- function ChainRulesCore. frule ((_, Δd, Δx):: Tuple{Any,Any,Any} , :: typeof (Distributions. _logpdf), d:: Dirichlet , x:: AbstractVector{<:Real} )
23
+ function ChainRulesCore. frule (:: ChainRulesCore.RuleConfig , (_, Δd, Δx):: Tuple{Any,Any,Any} , :: typeof (Distributions. _logpdf), d:: Dirichlet , x:: AbstractVector{<:Real} )
24
24
Ω = Distributions. _logpdf (d, x)
25
25
∂alpha = sum (Broadcast. instantiate (Broadcast. broadcasted (Δd. alpha, Δx, d. alpha, x) do Δalphai, Δxi, alphai, xi
26
26
StatsFuns. xlogy (Δalphai, xi) + (alphai - 1 ) * Δxi / xi
@@ -33,7 +33,7 @@ function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(Distri
33
33
return Ω, ΔΩ
34
34
end
35
35
36
- function ChainRulesCore. rrule (:: typeof (Distributions. _logpdf), d:: T , x:: AbstractVector{<:Real} ) where {T<: Dirichlet }
36
+ function ChainRulesCore. rrule (:: ChainRulesCore.RuleConfig , :: typeof (Distributions. _logpdf), d:: T , x:: AbstractVector{<:Real} ) where {T<: Dirichlet }
37
37
Ω = Distributions. _logpdf (d, x)
38
38
isfinite_Ω = isfinite (Ω)
39
39
alpha = d. alpha
0 commit comments