Skip to content

Commit d5a293a

Browse files
committed
rrule tests
1 parent d407bf2 commit d5a293a

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

src/multivariate/dirichlet.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -395,9 +395,9 @@ end
395395
function ChainRulesCore.rrule(DT::Union{Type{Dirichlet{T}}, Type{Dirichlet}}, alpha::AbstractVector{T}; check_args = true) where {T}
396396
d = DT(alpha; check_args=check_args)
397397
function dirichlet_pullback(d_dir)
398-
d_dir = ChainRulesCore.unthunk(alpha)
399-
@info typeof(d_dir)
400-
return (ChainRulesCore.NoTangent(), d_dir.alpha)
398+
d_dir = ChainRulesCore.unthunk(d_dir)
399+
∂l = d_dir.lmnB * (SpecialFunctions.digamma.(alpha) .- SpecialFunctions.digamma.(d.alpha0))
400+
return (ChainRulesCore.NoTangent(), d_dir.alpha .+ d_dir.alpha0 .+ ∂l)
401401
end
402402
return d, dirichlet_pullback
403403
end

test/dirichlet.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,9 @@ end
134134
alpha = rand(n)
135135
Δalpha = randn(n)
136136
d2, ∂d = ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha)
137-
ChainRulesTestUtils.test_frule(Dirichlet ChainRulesCore.NoTangent(), alpha Δalpha, check_inferred=false)
137+
ChainRulesTestUtils.test_frule(Dirichlet ChainRulesCore.NoTangent(), alpha Δalpha, check_inferred=true)
138138

139-
_, dp = ChainRulesCore.rrule(Dirichlet ChainRulesCore.NoTangent(), alpha Δalpha)
139+
_, dp = ChainRulesCore.rrule(Dirichlet, alpha)
140140
ChainRulesTestUtils.test_rrule(Dirichlet{Float64} ChainRulesCore.NoTangent(), alpha)
141141
end
142142
end

0 commit comments

Comments
 (0)