Skip to content

Commit b356da0

Browse files
mharradoniampritishpatildevmotion
authored
Add Convolve for DiscreteNonParametric (Redux) (#1850)
* Add convolve for DiscreteNonParametric DiscreteNonParametric convolution has a very nice trivial closed form. It was not implemented. This pull request implements it. * Update src/convolution.jl Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Use Set, instead of splatting. Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Fix type stability of elements. Doesn't preserve the type of the Vector, but perhaps this is better .... Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Apply suggestions from code review use functions to access the support and probabilities, and write as one loop. Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Added a test set. Removed check args: We know the convovultion is a proper distribution. * minor rename for consistency * Formatting, test improvements suggested by devmotion (and a few more) * Formatting --------- Co-authored-by: iampritishpatil <iampritishpatil@gmail.com> Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
1 parent 6af1e2f commit b356da0

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

src/convolution.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ and one of
1212
* [`NegativeBinomial`](@ref)
1313
* [`Geometric`](@ref)
1414
* [`Poisson`](@ref)
15+
* [`DiscreteNonParametric`](@ref)
1516
* [`Normal`](@ref)
1617
* [`Cauchy`](@ref)
1718
* [`Chisq`](@ref)
@@ -47,6 +48,19 @@ end
4748
convolve(d1::Poisson, d2::Poisson) = Poisson(d1.λ + d2.λ)
4849

4950

51+
function convolve(d1::DiscreteNonParametric, d2::DiscreteNonParametric)
52+
support_conv = collect(Set(s1 + s2 for s1 in support(d1), s2 in support(d2)))
53+
sort!(support_conv) #for fast index finding below
54+
probs1 = probs(d1)
55+
probs2 = probs(d2)
56+
p_conv = zeros(Base.promote_eltype(probs1, probs2), length(support_conv))
57+
for (s1, p1) in zip(support(d1), probs(d1)), (s2, p2) in zip(support(d2), probs(d2))
58+
idx = searchsortedfirst(support_conv, s1+s2)
59+
p_conv[idx] += p1*p2
60+
end
61+
DiscreteNonParametric(support_conv, p_conv,check_args=false)
62+
end
63+
5064
# continuous univariate
5165
convolve(d1::Normal, d2::Normal) = Normal(d1.μ + d2.μ, hypot(d1.σ, d2.σ))
5266
convolve(d1::Cauchy, d2::Cauchy) = Cauchy(d1.μ + d2.μ, d1.σ + d2.σ)

test/convolution.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,28 @@ using Test
6666
@test d3 isa Poisson
6767
@test d3.λ == 0.5
6868
end
69+
70+
@testset "DiscreteNonParametric" begin
71+
d1 = DiscreteNonParametric([0, 1], [0.5, 0.5])
72+
d2 = DiscreteNonParametric([1, 2], [0.5, 0.5])
73+
d_eps = DiscreteNonParametric([prevfloat(0.0), 0.0, nextfloat(0.0), 1.0], fill(1//4, 4))
74+
d10 = DiscreteNonParametric((1//10):(1//10):1, fill(1//10, 10))
75+
76+
d_int_simple = @inferred(convolve(d1, d2))
77+
@test d_int_simple isa DiscreteNonParametric
78+
@test support(d_int_simple) == [1, 2, 3]
79+
@test probs(d_int_simple) == [0.25, 0.5, 0.25]
80+
81+
d_rat = convolve(d10, d10)
82+
@test support(d_rat) == (1//5):(1//10):2
83+
@test probs(d_rat) == [1//100, 1//50, 3//100, 1//25, 1//20, 3//50, 7//100, 2//25, 9//100, 1//10,
84+
9//100, 2//25, 7//100, 3//50, 1//20, 1//25, 3//100, 1//50, 1//100]
85+
86+
d_float_supp = convolve(d_eps, d_eps)
87+
@test support(d_float_supp) == [2 * prevfloat(0.0), prevfloat(0.0), 0.0, nextfloat(0.0), 2 * nextfloat(0.0), 1.0, 2.0]
88+
@test probs(d_float_supp) == [1//16, 1//8, 3//16, 1//8, 1//16, 3//8, 1//16]
89+
end
90+
6991
end
7092

7193
@testset "continuous univariate" begin

0 commit comments

Comments
 (0)