Skip to content

Commit 3946acc

Browse files
fix type stability of sampling from Chisq, TDist, Gamma (JuliaStats#1885)
* fix type stability of sampling from `Chisq`, `TDist`, `Gamma` * fix remove type specification in `rand(Exponential)` Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * fix type specificaton in `rand(TDist)` Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * fix remove type test for `rand(Chisq)` * fix make `Exponential` use the `Normal` sampling type policy * fix missing type signature * fix type signature for `rand(Exponential)` Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * fix use `@inferred` in tests for `Gamma` Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * fix use `@inferred` in tests for `TDist` Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * add type stability tests for `rand(Exponential)` * add type stability test for `rand(Chisq)` * fix remove type stability test for `entropy(TDist)` (not stable) --------- Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
1 parent 13029c0 commit 3946acc

File tree

7 files changed

+56
-28
lines changed

7 files changed

+56
-28
lines changed

src/samplers/gamma.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,6 @@ end
225225

226226
function rand(rng::AbstractRNG, s::GammaIPSampler)
227227
x = rand(rng, s.s)
228-
e = randexp(rng)
228+
e = randexp(rng, typeof(x))
229229
x*exp(s.nia*e)
230230
end

src/univariate/continuous/exponential.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ cf(d::Exponential, t::Real) = 1/(1 - t * im * scale(d))
105105

106106

107107
#### Sampling
108-
rand(rng::AbstractRNG, d::Exponential) = xval(d, randexp(rng))
108+
rand(rng::AbstractRNG, d::Exponential{T}) where {T} = xval(d, randexp(rng, float(T)))
109109

110110

111111
#### Fit model

src/univariate/continuous/tdist.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ end
8282
function rand(rng::AbstractRNG, d::TDist)
8383
ν = d.ν
8484
z = sqrt(rand(rng, Chisq{typeof(ν)}(ν)) / ν)
85-
return randn(rng) / (isinf(ν) ? one(z) : z)
85+
return randn(rng, typeof(z)) / (isinf(ν) ? one(z) : z)
8686
end
8787

8888
function cf(d::TDist{T}, t::Real) where T <: Real

test/univariate/continuous/chisq.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,9 @@
1-
test_cgf(Chisq(1), (0.49, -1, -100, -1f6))
2-
test_cgf(Chisq(3), (0.49, -1, -100, -1f6))
1+
2+
@testset "Chisq" begin
3+
test_cgf(Chisq(1), (0.49, -1, -100, -1.0f6))
4+
test_cgf(Chisq(3), (0.49, -1, -100, -1.0f6))
5+
6+
for T in (Float32, Float64)
7+
@test @inferred(rand(Chisq(T(1)))) isa T
8+
end
9+
end
Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11

2-
test_cgf(Exponential(1), (0.9, -1, -100f0, -1e6))
3-
test_cgf(Exponential(0.91), (0.9, -1, -100f0, -1e6))
4-
test_cgf(Exponential(10 ), (0.08, -1, -100f0, -1e6))
2+
@testset "Exponential" begin
3+
test_cgf(Exponential(1), (0.9, -1, -100f0, -1e6))
4+
test_cgf(Exponential(0.91), (0.9, -1, -100f0, -1e6))
5+
test_cgf(Exponential(10 ), (0.08, -1, -100f0, -1e6))
6+
7+
for T in (Float32, Float64)
8+
@test @inferred(rand(Exponential(T(1)))) isa T
9+
end
10+
end

test/univariate/continuous/gamma.jl

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,33 @@
11
using Test, Distributions, OffsetArrays
22

3-
test_cgf(Gamma(1 ,1 ), (0.9, -1, -100f0, -1e6))
4-
test_cgf(Gamma(10 ,1 ), (0.9, -1, -100f0, -1e6))
5-
test_cgf(Gamma(0.2, 10), (0.08, -1, -100f0, -1e6))
3+
@testset "Gamma" begin
4+
test_cgf(Gamma(1, 1), (0.9, -1, -100.0f0, -1e6))
5+
test_cgf(Gamma(10, 1), (0.9, -1, -100.0f0, -1e6))
6+
test_cgf(Gamma(0.2, 10), (0.08, -1, -100.0f0, -1e6))
67

7-
@testset "Gamma suffstats and OffsetArrays" begin
8-
a = rand(Gamma(), 11)
9-
wa = 1.0:11.0
8+
@testset "Gamma suffstats and OffsetArrays" begin
9+
a = rand(Gamma(), 11)
10+
wa = 1.0:11.0
1011

11-
resulta = @inferred(suffstats(Gamma, a))
12+
resulta = @inferred(suffstats(Gamma, a))
1213

13-
resultwa = @inferred(suffstats(Gamma, a, wa))
14+
resultwa = @inferred(suffstats(Gamma, a, wa))
1415

15-
b = OffsetArray(a, -5:5)
16-
wb = OffsetArray(wa, -5:5)
16+
b = OffsetArray(a, -5:5)
17+
wb = OffsetArray(wa, -5:5)
1718

18-
resultb = @inferred(suffstats(Gamma, b))
19-
@test resulta == resultb
19+
resultb = @inferred(suffstats(Gamma, b))
20+
@test resulta == resultb
2021

21-
resultwb = @inferred(suffstats(Gamma, b, wb))
22-
@test resultwa == resultwb
22+
resultwb = @inferred(suffstats(Gamma, b, wb))
23+
@test resultwa == resultwb
2324

24-
@test_throws DimensionMismatch suffstats(Gamma, a, wb)
25+
@test_throws DimensionMismatch suffstats(Gamma, a, wb)
26+
end
27+
28+
for T in (Float32, Float64)
29+
@test @inferred(rand(Gamma(T(1), T(1)))) isa T
30+
@test @inferred(rand(Gamma(1/T(2), T(1)))) isa T
31+
@test @inferred(rand(Gamma(T(2), T(1)))) isa T
32+
end
2533
end

test/univariate/continuous/tdist.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,17 @@ using ForwardDiff
33

44
using Test
55

6-
@testset "Type stability of `rand` (#1614)" begin
7-
if VERSION >= v"1.9.0-DEV.348"
8-
# randn(::BigFloat) was only added in https://github.com/JuliaLang/julia/pull/44714
9-
@inferred(rand(TDist(big"1.0")))
6+
@testset "TDist" begin
7+
@testset "Type stability of `rand` (#1614)" begin
8+
if VERSION >= v"1.9.0-DEV.348"
9+
# randn(::BigFloat) was only added in https://github.com/JuliaLang/julia/pull/44714
10+
@inferred(rand(TDist(big"1.0")))
11+
end
12+
@inferred(rand(TDist(ForwardDiff.Dual(1.0))))
13+
14+
end
15+
16+
for T in (Float32, Float64)
17+
@test @inferred(rand(TDist(T(1)))) isa T
1018
end
11-
@inferred(rand(TDist(ForwardDiff.Dual(1.0))))
1219
end

0 commit comments

Comments
 (0)