Skip to content

Commit 036a24d

Browse files
committed
Add mean et al. for truncated log normal
Fixes 709
1 parent 65f056c commit 036a24d

File tree

7 files changed

+163
-0
lines changed

7 files changed

+163
-0
lines changed

src/truncate.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ include(joinpath("truncated", "exponential.jl"))
261261
include(joinpath("truncated", "uniform.jl"))
262262
include(joinpath("truncated", "loguniform.jl"))
263263
include(joinpath("truncated", "discrete_uniform.jl"))
264+
include(joinpath("truncated", "lognormal.jl"))
264265

265266
#### Utilities
266267

src/truncated/lognormal.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Moments of the truncated log-normal can be computed directly from the moment generating
2+
# function of the truncated normal:
3+
# Let Y ~ LogNormal(μ, σ) truncated to (a, b). Then log(Y) ~ Normal(μ, σ) truncated
4+
# to (log(a), log(b)), and E[Y^n] = E[(e^log(Y))^n] = E[e^(nlog(Y))] = mgf(log(Y), n).
5+
6+
# Given `truncate(LogNormal(μ, σ), a, b)`, return `truncate(Normal(μ, σ), log(a), log(b))`
7+
function _truncnorm(d::Truncated{<:LogNormal})
8+
μ, σ = params(d.untruncated)
9+
T = partype(d)
10+
a = d.lower === nothing ? nothing : log(T(d.lower))
11+
b = d.upper === nothing ? nothing : log(T(d.upper))
12+
return truncated(Normal(μ, σ), a, b)
13+
end
14+
15+
mean(d::Truncated{<:LogNormal}) = mgf(_truncnorm(d), 1)
16+
17+
function var(d::Truncated{<:LogNormal})
18+
tn = _truncnorm(d)
19+
# Ensure the variance doesn't end up negative, which can occur due to numerical issues
20+
return max(mgf(tn, 2) - mgf(tn, 1)^2, 0)
21+
end
22+
23+
function skewness(d::Truncated{<:LogNormal})
24+
tn = _truncnorm(d)
25+
m1 = mgf(tn, 1)
26+
m2 = mgf(tn, 2)
27+
m3 = mgf(tn, 3)
28+
sqm1 = m1^2
29+
v = m2 - sqm1
30+
return (m3 + m1 * (-3 * m2 + 2 * sqm1)) / (v * sqrt(v))
31+
end
32+
33+
function kurtosis(d::Truncated{<:LogNormal})
34+
tn = _truncnorm(d)
35+
m1 = mgf(tn, 1)
36+
m2 = mgf(tn, 2)
37+
m3 = mgf(tn, 3)
38+
m4 = mgf(tn, 4)
39+
v = m2 - m1^2
40+
return @horner(m1, m4, -4m3, 6m2, 0, -3) / v^2 - 3
41+
end
42+
43+
# TODO: The entropy can be written "directly" as well, according to Mathematica, but
44+
# the expression for it fills me with regret. There are some recognizable components,
45+
# so a sufficiently motivated person could try to manually simplify it into something
46+
# comprehensible. For reference, you can obtain the entropy with Mathematica like so:
47+
#
48+
# d = TruncatedDistribution[{a, b}, LogNormalDistribution[m, s]];
49+
# Expectation[-LogLikelihood[d, {x}], Distributed[x, d],
50+
# Assumptions -> Element[x | m | s | a | b, Reals] && s > 0 && 0 < a < x < b]

src/truncated/normal.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,36 @@ function entropy(d::Truncated{<:Normal{<:Real},Continuous})
118118
0.5 * (log2π + 1.) + log* z) + (aφa - bφb) / (2.0 * z)
119119
end
120120

121+
function mgf(d::Truncated{<:Normal{<:Real},Continuous}, t::Real)
122+
T = float(promote_type(partype(d), typeof(t)))
123+
a = T(minimum(d))
124+
b = T(maximum(d))
125+
if isnan(a) || isnan(b) # TODO: Disallow constructing `Truncated` with a `NaN` bound?
126+
return T(NaN)
127+
elseif isinf(a) && isinf(b) && a != b
128+
# Distribution is `Truncated`-wrapped but not actually truncated
129+
return T(mgf(d.untruncated, t))
130+
elseif a == b
131+
# Truncated to a Dirac distribution; this is `mgf(Dirac(a), t)`
132+
return exp(a * t)
133+
end
134+
d0 = d.untruncated
135+
μ = mean(d0)
136+
σ = std(d0)
137+
σ²t = σ^2 * t
138+
a′ = (a - μ) / σ
139+
b′ = (b - μ) / σ
140+
stdnorm = Normal{T}(zero(T), one(T))
141+
# log((Φ(b′ - σ²t) - Φ(a′ - σ²t)) / (Φ(b′) - Φ(a′)))
142+
logratio = if isfinite(a) && isfinite(b) # doubly truncated
143+
logdiffcdf(stdnorm, b′ - σ²t, a′ - σ²t) - logdiffcdf(stdnorm, b′, a′)
144+
elseif isfinite(a) # left truncated: b = ∞, Φ(b′) = Φ(b′ - σ²t) = 1
145+
logccdf(stdnorm, a′ - σ²t) - logccdf(stdnorm, a′)
146+
else # isfinite(b), right truncated: a = ∞, Φ(a′) = Φ(a′ - σ²t) = 0
147+
logcdf(stdnorm, b′ - σ²t) - logcdf(stdnorm, b′)
148+
end
149+
return exp(t *+ σ²t / 2) + logratio)
150+
end
121151

122152
### sampling
123153

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ const tests = [
2020
"truncated/exponential",
2121
"truncated/uniform",
2222
"truncated/discrete_uniform",
23+
"truncated/lognormal",
2324
"censored",
2425
"univariate/continuous/normal",
2526
"univariate/continuous/laplace",

test/testutils.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,17 @@ function _linspace(a::Float64, b::Float64, n::Int)
1818
return r
1919
end
2020

21+
# Enables testing against values computed at high precision by transforming an expression
22+
# that uses numeric literals and constants to wrap those in `big()`, similar to how the
23+
# high-precision values for irrational constants are defined with `Base.@irrational` and
24+
# in IrrationalConstants.jl. See e.g. `test/truncated/normal.jl` for example use.
25+
bigly(x) = x
26+
bigly(x::Symbol) = x in (, :ℯ, :Inf, :NaN) ? Expr(:call, :big, x) : x
27+
bigly(x::Real) = Expr(:call, :big, x)
28+
bigly(x::Expr) = (map!(bigly, x.args, x.args); x)
29+
macro bigly(ex)
30+
return esc(bigly(ex))
31+
end
2132

2233
#################################################
2334
#

test/truncated/lognormal.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using Distributions, Test
2+
using Distributions: expectation
3+
4+
naive_moment(d, n, μ, σ²) == sqrt(σ²); expectation(x -> ((x - μ) / σ)^n, d))
5+
6+
@testset "Truncated log normal" begin
7+
@testset "truncated(LogNormal{$T}(0, 1), ℯ⁻², ℯ²)" for T in (Float32, Float64, BigFloat)
8+
d = truncated(LogNormal{T}(zero(T), one(T)), exp(T(-2)), exp(T(2)))
9+
tn = truncated(Normal{BigFloat}(big(0.0), big(1.0)), -2, 2)
10+
bigmean = mgf(tn, 1)
11+
bigvar = mgf(tn, 2) - bigmean^2
12+
@test @inferred(mean(d)) bigmean
13+
@test @inferred(var(d)) bigvar
14+
@test @inferred(median(d)) one(T)
15+
@test @inferred(skewness(d)) naive_moment(d, 3, bigmean, bigvar)
16+
@test @inferred(kurtosis(d)) naive_moment(d, 4, bigmean, bigvar) - big(3)
17+
@test mean(d) isa T
18+
end
19+
@testset "Bound with no effect" begin
20+
# Uses the example distribution from issue #709, though what's tested here is
21+
# mostly unrelated to that issue (aside from `mean` not erroring).
22+
# The specified left truncation at 0 has no effect for `LogNormal`
23+
d1 = truncated(LogNormal(1, 5), 0, 1e5)
24+
@test mean(d1) 0 atol=eps()
25+
v1 = var(d1)
26+
@test v1 0 atol=eps()
27+
# Without a `max(_, 0)`, this would be within machine precision of 0 (as above) but
28+
# numerically negative, which could cause downstream issues that assume a nonnegative
29+
# variance
30+
@test v1 >= 0
31+
# Compare results with not specifying a lower bound at all
32+
d2 = truncated(LogNormal(1, 5); upper=1e5)
33+
@test mean(d1) == mean(d2)
34+
@test var(d1) == var(d2)
35+
end
36+
end

test/truncated/normal.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,37 @@ end
6969
@test isfinite(pdf(trunc, x))
7070
end
7171
end
72+
73+
@testset "Truncated normal MGF" begin
74+
two = big(2)
75+
sqrt2 = sqrt(two)
76+
invsqrt2 = inv(sqrt2)
77+
inv2sqrt2 = inv(two * sqrt2)
78+
twoerfsqrt2 = two * erf(sqrt2)
79+
80+
for T in (Float32, Float64, BigFloat)
81+
d = truncated(Normal{T}(zero(T), one(T)), -2, 2)
82+
@test @inferred(mgf(d, 0)) == 1
83+
@test @inferred(mgf(d, 1)) @bigly sqrt(ℯ) * (erf(invsqrt2) + erf(3 * invsqrt2)) / twoerfsqrt2
84+
@test @inferred(mgf(d, 2.5)) @bigly exp(25//8) * (erf(9 * inv2sqrt2) - erf(inv2sqrt2)) / twoerfsqrt2
85+
end
86+
87+
d = truncated(Normal(3, 10), 7, 8)
88+
@test mgf(d, 0) == 1
89+
@test mgf(d, 1) == 0
90+
91+
d = truncated(Normal(27, 3); lower=0)
92+
@test mgf(d, 0) == 1
93+
@test mgf(d, 1) @bigly 2 * exp(63//2) / (1 + erf(9 * invsqrt2))
94+
@test mgf(d, 2.5) @bigly 2 * exp(765//8) / (1 + erf(9 * invsqrt2))
95+
96+
d = truncated(Normal(-5, 1); upper=-10)
97+
@test mgf(d, 0) == 1
98+
@test mgf(d, 1) @bigly erfc(3 * sqrt2) / (exp(9//2) * erfc(5 * invsqrt2))
99+
100+
@test isnan(mgf(truncated(Normal(); upper=NaN), 0))
101+
102+
@test mgf(truncated(Normal(), -Inf, Inf), 1) == mgf(Normal(), 1)
103+
104+
@test mgf(truncated(Normal(), 2, 2), 1) == exp(2)
105+
end

0 commit comments

Comments
 (0)