Skip to content

Commit 8653c2d

Browse files
committed
Fixes and tests
1 parent f74d096 commit 8653c2d

File tree

6 files changed

+65
-23
lines changed

6 files changed

+65
-23
lines changed

src/truncated/lognormal.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,28 @@
66
# Given `truncate(LogNormal(μ, σ), a, b)`, return `truncate(Normal(μ, σ), log(a), log(b))`
77
function _truncnorm(d::Truncated{<:LogNormal})
88
μ, σ = params(d.untruncated)
9-
a = d.lower === nothing ? nothing : log(minimum(d))
10-
b = d.upper === nothing ? nothing : log(maximum(d))
9+
T = partype(d)
10+
a = d.lower === nothing ? nothing : log(T(minimum(d)))
11+
b = d.upper === nothing ? nothing : log(T(maximum(d)))
1112
return truncated(Normal(μ, σ), a, b)
1213
end
1314

1415
mean(d::Truncated{<:LogNormal}) = mgf(_truncnorm(d), 1)
1516

1617
function var(d::Truncated{<:LogNormal})
1718
tn = _truncnorm(d)
18-
m1 = mgf(tn, 1)
19-
m2 = sqrt(mgf(tn, 2))
20-
return (m2 - m1) * (m2 + m1)
19+
# Ensure the variance doesn't end up negative, which can occur due to numerical issues
20+
return max(mgf(tn, 2) - mgf(tn, 1)^2, 0)
2121
end
2222

2323
function skewness(d::Truncated{<:LogNormal})
2424
tn = _truncnorm(d)
2525
m1 = mgf(tn, 1)
26-
m2 = sqrt(mgf(tn, 2))
26+
m2 = mgf(tn, 2)
2727
m3 = mgf(tn, 3)
28-
v = (m2 - m1) * (m2 + m1)
29-
return (m3 - 3 * m1 * v - m1^3) / (v * sqrt(v))
28+
sqm1 = m1^2
29+
v = m2 - sqm1
30+
return (m3 + m1 * (-3 * m2 + 2 * sqm1)) / (v * sqrt(v))
3031
end
3132

3233
function kurtosis(d::Truncated{<:LogNormal})
@@ -35,8 +36,7 @@ function kurtosis(d::Truncated{<:LogNormal})
3536
m2 = mgf(tn, 2)
3637
m3 = mgf(tn, 3)
3738
m4 = mgf(tn, 4)
38-
sm2 = sqrt(m2)
39-
v = (sm2 - m1) * (sm2 + m1)
39+
v = m2 - m1^2
4040
return evalpoly(m1, (m4, -4m3, 6m2, 0, -3)) / v^2 - 3
4141
end
4242

src/truncated/normal.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,9 @@ function entropy(d::Truncated{<:Normal{<:Real},Continuous})
119119
end
120120

121121
function mgf(d::Truncated{<:Normal{<:Real},Continuous}, t::Real)
122-
a, b = extrema(d)
123-
T = promote_type(partype(d), typeof(t), typeof(a))
122+
T = promote_type(partype(d), typeof(t))
123+
a = T(minimum(d))
124+
b = T(maximum(d))
124125
if isnan(a) || isnan(b) # TODO: Disallow constructing `Truncated` with a `NaN` bound?
125126
return T(NaN)
126127
elseif isinf(a) && isinf(b) && a != b

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ const tests = [
2020
"truncated/exponential",
2121
"truncated/uniform",
2222
"truncated/discrete_uniform",
23+
"truncated/lognormal",
2324
"censored",
2425
"univariate/continuous/normal",
2526
"univariate/continuous/laplace",

test/testutils.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,17 @@ function _linspace(a::Float64, b::Float64, n::Int)
1818
return r
1919
end
2020

21+
# Enables testing against values computed at high precision by transforming an expression
22+
# that uses numeric literals and constants to wrap those in `big()`, similar to how the
23+
# high-precision values for irrational constants are defined with `Base.@irrational` and
24+
# in IrrationalConstants.jl. See e.g. `test/truncated/normal.jl` for example use.
25+
bigly(x) = x
26+
bigly(x::Symbol) = x in (, :ℯ, :Inf, :NaN) ? Expr(:call, :big, x) : x
27+
bigly(x::Real) = Expr(:call, :big, x)
28+
bigly(x::Expr) = (map!(bigly, x.args, x.args); x)
29+
macro bigly(ex)
30+
return esc(bigly(ex))
31+
end
2132

2233
#################################################
2334
#

test/truncated/lognormal.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using Distributions, Test
2+
using Distributions: expectation
3+
4+
naive_moment(d, n, μ, σ²) == sqrt(σ²); expectation(x -> ((x - μ) / σ)^n, d))
5+
6+
@testset "Truncated log normal" begin
7+
@testset "truncated(LogNormal{$T}(0, 1), ℯ⁻², ℯ²)" for T in (Float32, Float64, BigFloat)
8+
d = truncated(LogNormal{T}(zero(T), one(T)), exp(T(-2)), exp(T(2)))
9+
tn = truncated(Normal{BigFloat}(big(0.0), big(1.0)), -2, 2)
10+
bigmean = mgf(tn, 1)
11+
bigvar = mgf(tn, 2) - bigmean^2
12+
@test @inferred(mean(d)) bigmean
13+
@test @inferred(var(d)) bigvar
14+
@test @inferred(median(d)) one(T)
15+
@test @inferred(skewness(d)) naive_moment(d, 3, bigmean, bigvar)
16+
@test @inferred(kurtosis(d)) naive_moment(d, 4, bigmean, bigvar) - big(3)
17+
@test mean(d) isa T
18+
end
19+
@testset "Bound with no effect" begin
20+
# Uses the example distribution from issue #709, though what's tested here is
21+
# mostly unrelated to that issue (aside from `mean` not erroring).
22+
# The specified left truncation at 0 has no effect for `LogNormal`
23+
d1 = truncated(LogNormal(1, 5), 0, 1e5)
24+
@test mean(d1) 0 atol=eps()
25+
v1 = var(d1)
26+
@test v1 0 atol=eps()
27+
# Without a `max(_, 0)`, this would be within machine precision of 0 (as above) but
28+
# numerically negative, which could cause downstream issues that assume a nonnegative
29+
# variance
30+
@test v1 > 0
31+
# Compare results with not specifying a lower bound at all
32+
d2 = truncated(LogNormal(1, 5); upper=1e5)
33+
@test mean(d1) == mean(d2)
34+
@test var(d1) == var(d2)
35+
end
36+
end

test/truncated/normal.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,12 @@ end
7070
end
7171
end
7272

73-
bigly(x) = x
74-
bigly(x::Symbol) = x === || x === :ℯ ? Expr(:call, :big, x) : x
75-
bigly(x::Real) = Expr(:call, :big, x)
76-
bigly(x::Expr) = (map!(bigly, x.args, x.args); x)
77-
macro bigly(ex)
78-
return esc(bigly(ex))
79-
end
80-
8173
@testset "Truncated normal MGF" begin
82-
sqrt2 = sqrt(big(2))
74+
two = big(2)
75+
sqrt2 = sqrt(two)
8376
invsqrt2 = inv(sqrt2)
84-
inv2sqrt2 = inv(big(2) * sqrt2)
85-
twoerfsqrt2 = big(2) * erf(sqrt2)
77+
inv2sqrt2 = inv(two * sqrt2)
78+
twoerfsqrt2 = two * erf(sqrt2)
8679

8780
for T in (Float32, Float64, BigFloat)
8881
d = truncated(Normal{T}(zero(T), one(T)), -2, 2)

0 commit comments

Comments
 (0)