Skip to content

Commit f74d096

Browse files
committed
More careful evaluation + add tests
1 parent 2a2cfaa commit f74d096

File tree

2 files changed

+64
-4
lines changed

2 files changed

+64
-4
lines changed

src/truncated/normal.jl

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,34 @@ function entropy(d::Truncated{<:Normal{<:Real},Continuous})
118118
0.5 * (log2π + 1.) + log* z) + (aφa - bφb) / (2.0 * z)
119119
end
120120

121-
function mgf(d::Truncated{Normal{T},Continuous}, t::Real) where {T}
121+
function mgf(d::Truncated{<:Normal{<:Real},Continuous}, t::Real)
122+
a, b = extrema(d)
123+
T = promote_type(partype(d), typeof(t), typeof(a))
124+
if isnan(a) || isnan(b) # TODO: Disallow constructing `Truncated` with a `NaN` bound?
125+
return T(NaN)
126+
elseif isinf(a) && isinf(b) && a != b
127+
# Distribution is `Truncated`-wrapped but not actually truncated
128+
return T(mgf(d.untruncated, t))
129+
elseif a == b
130+
# Truncated to a Dirac distribution; this is `mgf(Dirac(a), t)`
131+
return exp(a * t)
132+
end
122133
d0 = d.untruncated
123134
μ = mean(d0)
124135
σ = std(d0)
125136
σ²t = σ^2 * t
126-
a = (minimum(d) - μ) / σ - σ²t
127-
b = (maximum(d) - μ) / σ - σ²t
137+
a = (a - μ) / σ
138+
b = (b - μ) / σ
128139
stdnorm = Normal{T}(zero(T), one(T))
129-
return exp(t *+ σ²t / 2) + logdiffcdf(stdnorm, b, a) - d.logtp)
140+
# log((Φ(b′ - σ²t) - Φ(a′ - σ²t)) / (Φ(b′) - Φ(a′)))
141+
logratio = if isfinite(a) && isfinite(b) # doubly truncated
142+
logdiffcdf(stdnorm, b′ - σ²t, a′ - σ²t) - logdiffcdf(stdnorm, b′, a′)
143+
elseif isfinite(a) # left truncated: b = ∞, Φ(b′) = Φ(b′ - σ²t) = 1
144+
logccdf(stdnorm, a′ - σ²t) - logccdf(stdnorm, a′)
145+
else # isfinite(b), right truncated: a = ∞, Φ(a′) = Φ(a′ - σ²t) = 0
146+
logcdf(stdnorm, b′ - σ²t) - logcdf(stdnorm, b′)
147+
end
148+
return exp(t *+ σ²t / 2) + logratio)
130149
end
131150

132151
### sampling

test/truncated/normal.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,44 @@ end
6969
@test isfinite(pdf(trunc, x))
7070
end
7171
end
72+
73+
bigly(x) = x
74+
bigly(x::Symbol) = x === || x === :ℯ ? Expr(:call, :big, x) : x
75+
bigly(x::Real) = Expr(:call, :big, x)
76+
bigly(x::Expr) = (map!(bigly, x.args, x.args); x)
77+
macro bigly(ex)
78+
return esc(bigly(ex))
79+
end
80+
81+
@testset "Truncated normal MGF" begin
82+
sqrt2 = sqrt(big(2))
83+
invsqrt2 = inv(sqrt2)
84+
inv2sqrt2 = inv(big(2) * sqrt2)
85+
twoerfsqrt2 = big(2) * erf(sqrt2)
86+
87+
for T in (Float32, Float64, BigFloat)
88+
d = truncated(Normal{T}(zero(T), one(T)), -2, 2)
89+
@test @inferred(mgf(d, 0)) == 1
90+
@test @inferred(mgf(d, 1)) @bigly sqrt(ℯ) * (erf(invsqrt2) + erf(3 * invsqrt2)) / twoerfsqrt2
91+
@test @inferred(mgf(d, 2.5)) @bigly exp(25//8) * (erf(9 * inv2sqrt2) - erf(inv2sqrt2)) / twoerfsqrt2
92+
end
93+
94+
d = truncated(Normal(3, 10), 7, 8)
95+
@test mgf(d, 0) == 1
96+
@test mgf(d, 1) == 0
97+
98+
d = truncated(Normal(27, 3); lower=0)
99+
@test mgf(d, 0) == 1
100+
@test mgf(d, 1) @bigly 2 * exp(63//2) / (1 + erf(9 * invsqrt2))
101+
@test mgf(d, 2.5) @bigly 2 * exp(765//8) / (1 + erf(9 * invsqrt2))
102+
103+
d = truncated(Normal(-5, 1); upper=-10)
104+
@test mgf(d, 0) == 1
105+
@test mgf(d, 1) @bigly erfc(3 * sqrt2) / (exp(9//2) * erfc(5 * invsqrt2))
106+
107+
@test isnan(mgf(truncated(Normal(); upper=NaN), 0))
108+
109+
@test mgf(truncated(Normal(), -Inf, Inf), 1) == mgf(Normal(), 1)
110+
111+
@test mgf(truncated(Normal(), 2, 2), 1) == exp(2)
112+
end

0 commit comments

Comments
 (0)