Skip to content

Commit 03269b4

Browse files
devmotionsimsuraceSimone Carlo Surace
authored
Add rrule for logpdf of NegativeBinomial (completes #1568) (#1579)
* Add `rrule` for `logpdf` of `NegativeBinomial` * Remove unnecessary module prefix Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Use explicit division Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Refator and correct pullback * Add tests for rrule * Use `forward_fdm` for testing `rrule` * Fix tests * Update test/negativebinomial.jl * Fix tests (without `p = 1 - eps()`) * Use FD for all tests, use random parameters * Avoid type instability Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Split and rearrange ForwardDiff and rrule tests * Bump version * Fix typo * Clean tests (revert unrelated changes) and fix them Co-authored-by: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Co-authored-by: Simone Carlo Surace <simone.surace@artificialy.com>
1 parent c430f39 commit 03269b4

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

src/univariate/discrete/negativebinomial.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,38 @@ function cf(d::NegativeBinomial, t::Real)
135135
r, p = params(d)
136136
return (((1 - p) * cis(t)) / (1 - p * cis(t)))^r
137137
end
138+
139+
# ChainRules definitions
140+
141+
function ChainRulesCore.rrule(::typeof(logpdf), d::NegativeBinomial, k::Real)
142+
# Compute log probability
143+
r, p = params(d)
144+
edgecase = isone(p) && iszero(k)
145+
insupp = insupport(d, k)
146+
147+
# Primal computation
148+
Ω = r * log(p) + k * log1p(-p)
149+
if edgecase
150+
Ω = zero(Ω)
151+
elseif !insupp
152+
Ω = oftype(Ω, -Inf)
153+
else
154+
Ω = Ω - log(k + r) - logbeta(r, k + 1)
155+
end
156+
157+
# Define pullback
158+
function logpdf_NegativeBinomial_pullback(Δ)
159+
Δr = Δ * (log(p) - inv(k + r) - digamma(r) + digamma(r + k + 1))
160+
Δp = Δ * (r / p - k / (1 - p))
161+
if edgecase
162+
Δp = oftype(Δp, Δ * r)
163+
elseif !insupp
164+
Δr = oftype(Δr, NaN)
165+
Δp = oftype(Δp, NaN)
166+
end
167+
Δd = ChainRulesCore.Tangent{typeof(d)}(; r=Δr, p=Δp)
168+
return ChainRulesCore.NoTangent(), Δd, ChainRulesCore.NoTangent()
169+
end
170+
171+
return Ω, logpdf_NegativeBinomial_pullback
172+
end

test/negativebinomial.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
using Distributions
22
using Test, ForwardDiff
3+
using ChainRulesTestUtils
4+
using FiniteDifferences
35

4-
# Currently, most of the tests for NegativeBinomail are in the "ref" folder.
6+
# Currently, most of the tests for NegativeBinomial are in the "ref" folder.
57
# Eventually, we might want to consolidate the tests here
68

79
mydiffp(r, p, k) = r/p - k/(1 - p)
@@ -19,3 +21,25 @@ end
1921
@test logpdf(NegativeBinomial(0.5, 1.0), 1) === -Inf
2022
@test all(iszero, rand(NegativeBinomial(rand(), 1.0), 10))
2123
end
24+
25+
@testset "rrule: logpdf of NegativeBinomial" begin
26+
r = randexp()
27+
28+
# Test with values in and outside of support
29+
p = rand()
30+
dist = NegativeBinomial(r, p)
31+
fdm = central_fdm(5, 1; max_range=min(r, p, 1-p)/2) # avoids numerical issues with finite differencing
32+
for k in (0, 10, 42, -1, -5, -13)
33+
# Test both integers and floating point numbers.
34+
# For floating point numbers we have to tell FiniteDifferences explicitly that the
35+
# argument is non-differentiable. Otherwise it will compute `NaN` as derivative.
36+
test_rrule(logpdf, dist, k; fdm=fdm, nans=true)
37+
test_rrule(logpdf, dist, float(k) ChainRulesTestUtils.NoTangent(); fdm=fdm, nans=true)
38+
end
39+
40+
# Test edge case `p = 1` and `k = 0`
41+
dist = NegativeBinomial(r, 1)
42+
fdm = backward_fdm(5, 1; max_range = r/10)
43+
test_rrule(logpdf, dist, 0; fdm=fdm)
44+
test_rrule(logpdf, dist, 0.0 ChainRulesTestUtils.NoTangent(); fdm=fdm)
45+
end

0 commit comments

Comments
 (0)