Skip to content

Commit 08c56ea

Browse files
quildtidedevmotion
andauthored
Specialized vector rand! for many distributions (#1879)
* Test scalar rand separately from vector rand * Add specialized rand! for many distributions * Restore location of old NormalInverseGaussian tests * Remove duplication of inversegaussian in runtests.jl * Apply many suggestions from code review Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Apply other suggestions * Remove redundant new tests * Clean up more * Partially undo previous undo to changes to tests * Use xval for NormalCanon rand * Apply suggestions from code review Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Apply other recommendations to testutils * Fix erroneous ! * Address reviewer comments * `mean` not defined for `LogitNormal` * Copy RNG with `copy`, not `deepcopy` --------- Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
1 parent b219803 commit 08c56ea

File tree

16 files changed

+160
-21
lines changed

16 files changed

+160
-21
lines changed

src/univariate/continuous/exponential.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ cf(d::Exponential, t::Real) = 1/(1 - t * im * scale(d))
107107
#### Sampling
108108
rand(rng::AbstractRNG, d::Exponential{T}) where {T} = xval(d, randexp(rng, float(T)))
109109

110+
function rand!(rng::AbstractRNG, d::Exponential, A::AbstractArray{<:Real})
111+
randexp!(rng, A)
112+
map!(Base.Fix1(xval, d), A, A)
113+
return A
114+
end
110115

111116
#### Fit model
112117

src/univariate/continuous/logitnormal.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,14 @@ end
157157

158158
#### Sampling
159159

160-
rand(rng::AbstractRNG, d::LogitNormal) = logistic(randn(rng) * d.σ + d.μ)
160+
xval(d::LogitNormal, z::Real) = logistic(muladd(d.σ, z, d.μ))
161+
162+
rand(rng::AbstractRNG, d::LogitNormal) = xval(d, randn(rng))
163+
function rand!(rng::AbstractRNG, d::LogitNormal, A::AbstractArray{<:Real})
164+
randn!(rng, A)
165+
map!(Base.Fix1(xval, d), A, A)
166+
return A
167+
end
161168

162169
## Fitting
163170

src/univariate/continuous/lognormal.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,14 @@ end
156156

157157
#### Sampling
158158

159-
rand(rng::AbstractRNG, d::LogNormal) = exp(randn(rng) * d.σ + d.μ)
159+
xval(d::LogNormal, z::Real) = exp(muladd(d.σ, z, d.μ))
160+
161+
rand(rng::AbstractRNG, d::LogNormal) = xval(d, randn(rng))
162+
function rand!(rng::AbstractRNG, d::LogNormal, A::AbstractArray{<:Real})
163+
randn!(rng, A)
164+
map!(Base.Fix1(xval, d), A, A)
165+
return A
166+
end
160167

161168
## Fitting
162169

src/univariate/continuous/normal.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,14 @@ Base.:*(c::Real, d::Normal) = Normal(c * d.μ, abs(c) * d.σ)
114114

115115
#### Sampling
116116

117-
rand(rng::AbstractRNG, d::Normal{T}) where {T} = d.μ + d.σ * randn(rng, float(T))
117+
xval(d::Normal, z::Real) = muladd(d.σ, z, d.μ)
118118

119-
rand!(rng::AbstractRNG, d::Normal, A::AbstractArray{<:Real}) = A .= muladd.(d.σ, randn!(rng, A), d.μ)
119+
rand(rng::AbstractRNG, d::Normal{T}) where {T} = xval(d, randn(rng, float(T)))
120+
function rand!(rng::AbstractRNG, d::Normal, A::AbstractArray{<:Real})
121+
randn!(rng, A)
122+
map!(Base.Fix1(xval, d), A, A)
123+
return A
124+
end
120125

121126
#### Fitting
122127

src/univariate/continuous/normalcanon.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,13 @@ invlogccdf(d::NormalCanon, lp::Real) = xval(d, norminvlogccdf(lp))
8787

8888
#### Sampling
8989

90-
rand(rng::AbstractRNG, cf::NormalCanon) = cf.μ + randn(rng) / sqrt(cf.λ)
90+
rand(rng::AbstractRNG, cf::NormalCanon) = xval(cf, randn(rng))
91+
92+
function rand!(rng::AbstractRNG, cf::NormalCanon, A::AbstractArray{<:Real})
93+
randn!(rng, A)
94+
map!(Base.Fix1(xval, cf), A, A)
95+
return A
96+
end
9197

9298
#### Affine transformations
9399

src/univariate/continuous/pareto.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,14 @@ quantile(d::Pareto, p::Real) = cquantile(d, 1 - p)
110110

111111
#### Sampling
112112

113-
rand(rng::AbstractRNG, d::Pareto) = d.θ * exp(randexp(rng) / d.α)
113+
xval(d::Pareto, z::Real) = d.θ * exp(z / d.α)
114+
115+
rand(rng::AbstractRNG, d::Pareto) = xval(d, randexp(rng))
116+
function rand!(rng::AbstractRNG, d::Pareto, A::AbstractArray{<:Real})
117+
randexp!(rng, A)
118+
map!(Base.Fix1(xval, d), A, A)
119+
return A
120+
end
114121

115122
## Fitting
116123

src/univariate/continuous/pgeneralizedgaussian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ function rand(rng::AbstractRNG, d::PGeneralizedGaussian)
141141
inv_p = inv(d.p)
142142
g = Gamma(inv_p, 1)
143143
z = d.α * rand(rng, g)^inv_p
144-
if rand(rng) < 0.5
144+
if rand(rng, Bool)
145145
return d.μ - z
146146
else
147147
return d.μ + z

test/fit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ end
369369
for func in funcs, dist in (Laplace, Laplace{Float64})
370370
d = fit(dist, func[2](dist(5.0, 3.0), N + 1))
371371
@test isa(d, dist)
372-
@test isapprox(location(d), 5.0, atol=0.02)
372+
@test isapprox(location(d), 5.0, atol=0.03)
373373
@test isapprox(scale(d) , 3.0, atol=0.03)
374374
end
375375
end

test/multivariate/mvlognormal.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ end
105105
@test entropy(l1) entropy(l2)
106106
@test logpdf(l1,5.0) logpdf(l2,[5.0])
107107
@test pdf(l1,5.0) pdf(l2,[5.0])
108-
@test (Random.seed!(78393) ; [rand(l1)]) == (Random.seed!(78393) ; rand(l2))
109-
@test [rand(MersenneTwister(78393), l1)] == rand(MersenneTwister(78393), l2)
108+
@test (Random.seed!(78393) ; [rand(l1)]) (Random.seed!(78393) ; rand(l2))
109+
@test [rand(MersenneTwister(78393), l1)] rand(MersenneTwister(78393), l2)
110110
end
111111

112112
###### General Testing

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ const tests = [
2222
"truncated/discrete_uniform",
2323
"censored",
2424
"univariate/continuous/normal",
25+
"univariate/continuous/normalcanon",
2526
"univariate/continuous/laplace",
2627
"univariate/continuous/cauchy",
2728
"univariate/continuous/uniform",
@@ -83,6 +84,7 @@ const tests = [
8384
"univariate/continuous/noncentralchisq",
8485
"univariate/continuous/weibull",
8586
"pdfnorm",
87+
"univariate/continuous/pareto",
8688
"univariate/continuous/rician",
8789
"functionals",
8890
"density_interface",
@@ -143,9 +145,7 @@ const tests = [
143145
# "univariate/continuous/levy",
144146
# "univariate/continuous/noncentralbeta",
145147
# "univariate/continuous/noncentralf",
146-
# "univariate/continuous/normalcanon",
147148
# "univariate/continuous/normalinversegaussian",
148-
# "univariate/continuous/pareto",
149149
# "univariate/continuous/rayleigh",
150150
# "univariate/continuous/studentizedrange",
151151
# "univariate/continuous/symtriangular",

0 commit comments

Comments
 (0)