@@ -377,58 +377,60 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64},
377
377
end
378
378
379
379
# # Differentiation
380
- function ChainRulesCore. frule ((_, Δalpha), :: Type{DT} , alpha:: AbstractVector{T} ; check_args:: Bool = true ) where {T <: Real , DT <: Union{Dirichlet{T}, Dirichlet} }
380
+ function ChainRulesCore. frule ((_, Δalpha):: Tuple{Any,Any} , :: Type{DT} , alpha:: AbstractVector{T} ; check_args:: Bool = true ) where {T <: Real , DT <: Union{Dirichlet{T}, Dirichlet} }
381
381
d = DT (alpha; check_args= check_args)
382
- Δalpha = ChainRulesCore. unthunk (Δalpha)
383
382
∂alpha0 = sum (Δalpha)
384
383
digamma_alpha0 = SpecialFunctions. digamma (d. alpha0)
385
- ∂lmnB = sum (Broadcast. instantiate (Broadcast. broadcasted (Δalpha, alpha) do Δalpha_i, alpha_i
386
- Δalpha_i * (SpecialFunctions. digamma (alpha_i ) - digamma_alpha0)
384
+ ∂lmnB = sum (Broadcast. instantiate (Broadcast. broadcasted (Δalpha, alpha) do Δalphai, alphai
385
+ Δalphai * (SpecialFunctions. digamma (alphai ) - digamma_alpha0)
387
386
end ))
388
- backing = (alpha= Δalpha, alpha0= ∂alpha0, lmnB= ∂lmnB)
389
- t = ChainRulesCore. Tangent {typeof(d), NamedTuple{(:alpha, :alpha0, :lmnB), Tuple{typeof(alpha), typeof(d.alpha0), typeof(d.lmnB)}}} (backing)
390
- return d, t
387
+ Δd = ChainRulesCore. Tangent {typeof(d)} (; alpha= Δalpha, alpha0= ∂alpha0, lmnB= ∂lmnB)
388
+ return d, Δd
391
389
end
392
390
393
391
function ChainRulesCore. rrule (:: Type{DT} , alpha:: AbstractVector{T} ; check_args:: Bool = true ) where {T <: Real , DT <: Union{Dirichlet{T}, Dirichlet} }
394
392
d = DT (alpha; check_args= check_args)
395
- function dirichlet_pullback (d_dir )
396
- d_dir = ChainRulesCore . unthunk (d_dir )
397
- digamma_alpha0 = SpecialFunctions . digamma (d . alpha0 )
398
- dalpha = d_dir . alpha .+ d_dir . alpha0 .+ d_dir . lmnB .* (SpecialFunctions. digamma .(alpha) .- digamma_alpha0)
399
- return ChainRulesCore. NoTangent (), dalpha
393
+ digamma_alpha0 = SpecialFunctions . digamma (d . alpha0 )
394
+ function Dirichlet_pullback (_Δd )
395
+ Δd = ChainRulesCore . unthunk (_Δd )
396
+ Δalpha = Δd . alpha .+ Δd . alpha0 .+ Δd . lmnB .* (SpecialFunctions. digamma .(alpha) .- digamma_alpha0)
397
+ return ChainRulesCore. NoTangent (), Δalpha
400
398
end
401
- return d, dirichlet_pullback
399
+ return d, Dirichlet_pullback
402
400
end
403
401
404
- function ChainRulesCore. frule ((_, Δd, Δx), :: typeof (_logpdf), d:: Dirichlet , x:: AbstractVector{<:Real} )
405
- lp = _logpdf (d, x)
406
- ∂α_x = sum (Broadcast. instantiate (Broadcast. broadcasted (Δd. alpha, Δx, d. alpha, x) do Δalpha_i, Δx_i, alpha_i, x_i
407
- xlogy (Δalpha_i, x_i ) + (alpha_i - 1 ) * Δx_i / x_i
402
+ function ChainRulesCore. frule ((_, Δd, Δx):: Tuple{Any,Any,Any} , :: typeof (_logpdf), d:: Dirichlet , x:: AbstractVector{<:Real} )
403
+ Ω = _logpdf (d, x)
404
+ ∂alpha = sum (Broadcast. instantiate (Broadcast. broadcasted (Δd. alpha, Δx, d. alpha, x) do Δalphai, Δxi, alphai, xi
405
+ xlogy (Δalphai, xi ) + (alphai - 1 ) * Δxi / xi
408
406
end ))
409
- ∂l = - Δd. lmnB
410
- if ! insupport (d, x)
411
- ∂α_x = oftype (∂α_x, NaN )
407
+ ∂lmnB = - Δd. lmnB
408
+ ΔΩ = ∂alpha + ∂lmnB
409
+ if ! isfinite (Ω)
410
+ ΔΩ = oftype (ΔΩ, NaN )
412
411
end
413
- return (lp, ∂α_x + ∂l)
412
+ return Ω, ΔΩ
414
413
end
415
414
416
- function ChainRulesCore. rrule (:: typeof (_logpdf), d:: Dirichlet , x:: AbstractVector{<:Real} )
417
- y = _logpdf (d, x)
418
- function Dirichlet_logpdf_pullback (dy)
419
- ∂alpha = xlogy .(dy, x)
420
- ∂l = - dy
421
- ∂x = dy .* (d. alpha .- 1 ) ./ x
422
- ∂alpha0 = sum (∂alpha)
423
- if ! isfinite (y)
424
- ∂alpha = oftype (eltype (∂alpha), NaN ) * ∂alpha
425
- ∂l = oftype (∂l, NaN )
426
- ∂x = oftype (eltype (∂x), NaN ) * ∂x
427
- ∂alpha0 = oftype (eltype (∂alpha), NaN )
428
- end
429
- backing = (alpha = ∂alpha, alpha0 = ∂alpha0, lmnB= ∂l)
430
- ∂d = ChainRulesCore. Tangent {typeof(d), typeof(backing)} (backing)
431
- return (ChainRulesCore. NoTangent (), ∂d, ∂x)
415
+ function ChainRulesCore. rrule (:: typeof (_logpdf), d:: T , x:: AbstractVector{<:Real} ) where {T<: Dirichlet }
416
+ Ω = _logpdf (d, x)
417
+ isfinite_Ω = isfinite (Ω)
418
+ alpha = d. alpha
419
+ function _logpdf_Dirichlet_pullback (_ΔΩ)
420
+ ΔΩ = ChainRulesCore. unthunk (_ΔΩ)
421
+ ∂alpha = _logpdf_Dirichlet_∂alphai .(x, ΔΩ, isfinite_Ω)
422
+ ∂lmnB = isfinite_Ω ? - float (ΔΩ) : oftype (float (ΔΩ), NaN )
423
+ Δd = ChainRulesCore. Tangent {T} (; alpha= ∂alpha, lmnB= ∂lmnB)
424
+ Δx = _logpdf_Dirichlet_Δxi .(ΔΩ, alpha, x, isfinite_Ω)
425
+ return ChainRulesCore. NoTangent (), Δd, Δx
432
426
end
433
- return (y, Dirichlet_logpdf_pullback)
427
+ return Ω, _logpdf_Dirichlet_pullback
428
+ end
429
+ function _logpdf_Dirichlet_∂alphai (xi, ΔΩi, isfinite:: Bool )
430
+ ∂alphai = xlogy .(ΔΩi, xi)
431
+ return isfinite ? ∂alphai : oftype (∂alphai, NaN )
432
+ end
433
+ function _logpdf_Dirichlet_Δxi (ΔΩi, alphai, xi, isfinite:: Bool )
434
+ Δxi = ΔΩi * (alphai - 1 ) / xi
435
+ return isfinite ? Δxi : oftype (Δxi, NaN )
434
436
end
0 commit comments