From 5332f3a7c73cb22ce1fcd59918234a7041602dad Mon Sep 17 00:00:00 2001 From: quildtide <42811940+quildtide@users.noreply.github.com> Date: Thu, 15 Aug 2024 20:19:18 -0400 Subject: [PATCH 1/9] Fix Dirichlet rand overflows #1702 --- src/multivariate/dirichlet.jl | 46 ++++++++++++++++++++++++++++++++-- test/multivariate/dirichlet.jl | 25 ++++++++++++++++++ 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index b24980ec98..bf5469a696 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -160,14 +160,56 @@ function _rand!(rng::AbstractRNG, for (i, αi) in zip(eachindex(x), d.alpha) @inbounds x[i] = rand(rng, Gamma(αi)) end - lmul!(inv(sum(x)), x) # this returns x + Σ = sum(x) + if Σ == 0.0 + # Distribution behavior approaches categorical as Σα -> 0 + α = d.alpha + iΣα = inv(d.alpha0) + if isinf(iΣα) + # Dirichlet with ALL deeply subnormal parameters + α .*= floatmax(eltype(α)) + iΣα = inv(sum(α)) + end + x[rand(rng, Categorical(iΣα .* α))] = 1 + return x + end + + iΣ = inv(Σ) + if isinf(iΣ) + # Σ is deep subnormal + x .*= floatmax(eltype(x)) + iΣ = inv(sum(x)) + end + + lmul!(iΣ, x) # this returns x end function _rand!(rng::AbstractRNG, d::Dirichlet{T,<:FillArrays.AbstractFill{T}}, x::AbstractVector{<:Real}) where {T<:Real} rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x) - lmul!(inv(sum(x)), x) # this returns x + Σ = sum(x) + if Σ == 0.0 + # Distribution behavior approaches categorical as Σα -> 0 + α = d.alpha + iΣα = inv(alpha0) + if isinf(iΣα) + # Dirichlet with ALL deeply subnormal parameters + α .*= floatmax(eltype(α)) + iΣα = inv(sum(α)) + end + x[rand(rng, Categorical(iΣα .* α))] = 1 + return x + end + + iΣ = inv(Σ) + if isinf(iΣ) + # Σ is deep subnormal + x .*= floatmax(eltype(x)) + iΣ = inv(sum(x)) + end + + lmul!(iΣ, x) # this returns x end ####################################### diff --git a/test/multivariate/dirichlet.jl 
b/test/multivariate/dirichlet.jl index 78de162dca..7f6fd443b3 100644 --- a/test/multivariate/dirichlet.jl +++ b/test/multivariate/dirichlet.jl @@ -158,3 +158,28 @@ end end end end + +@testset "Dirichlet rand Inf and NaN (#1702)" begin + for d in [ + Dirichlet([8e-5, 1e-5, 2e-5]), + Dirichlet([8e-4, 1e-4, 2e-4]), + Dirichlet([4.5e-5, 8e-5]), + Dirichlet([6e-5, 2e-5, 3e-5, 4e-5, 5e-5]), + Dirichlet(FillArrays.Fill(1e-5, 5)) + ] + x = rand(d, 10^6) + @test mean(x, dims = 2) ≈ mean(d) atol=0.01 + @test var(x, dims = 2) ≈ var(d) atol=0.01 + end + + for (d, μ) in [ # Subnormal params cause mean(d) to error + (Dirichlet([5e-321, 1e-321, 4e-321]), [.5, .1, .4]), + (Dirichlet([1e-321, 2e-321, 3e-321, 4e-321]), [.1, .2, .3, .4]) + ] + x = rand(d, 10^6) + @test mean(x, dims = 2) ≈ μ atol=0.01 + end + + # Should equal [0.625061099164708, 0.37493890083529186, 0] on Julia v1.11.0-rc1 + @test sum(rand(Xoshiro(123322), Dirichlet([4.5e-5, 4.5e-5, 8e-5]))) == 1 +end \ No newline at end of file From 1b39c0c5ebce223e09dc87984de6b6b07ba728ac Mon Sep 17 00:00:00 2001 From: quildtide <42811940+quildtide@users.noreply.github.com> Date: Thu, 15 Aug 2024 20:59:30 -0400 Subject: [PATCH 2/9] Refactor code to reduce duplication --- src/multivariate/dirichlet.jl | 43 ++++++++++++----------------------- 1 file changed, 15 insertions(+), 28 deletions(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index bf5469a696..fc49bac3e2 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -154,12 +154,11 @@ end # sampling -function _rand!(rng::AbstractRNG, - d::Union{Dirichlet,DirichletCanon}, - x::AbstractVector{<:Real}) - for (i, αi) in zip(eachindex(x), d.alpha) - @inbounds x[i] = rand(rng, Gamma(αi)) - end +function _rand_handle_overflow!( + rng::AbstractRNG, + d::Union{Dirichlet,DirichletCanon}, + x::AbstractVector{<:Real} + ) Σ = sum(x) if Σ == 0.0 # Distribution behavior approaches categorical as Σα -> 0 @@ -184,32 +183,20 @@ function 
_rand!(rng::AbstractRNG, lmul!(iΣ, x) # this returns x end +function _rand!(rng::AbstractRNG, + d::Union{Dirichlet,DirichletCanon}, + x::AbstractVector{<:Real}) + for (i, αi) in zip(eachindex(x), d.alpha) + @inbounds x[i] = rand(rng, Gamma(αi)) + end + _rand_handle_overflow!(rng, d, x) +end + function _rand!(rng::AbstractRNG, d::Dirichlet{T,<:FillArrays.AbstractFill{T}}, x::AbstractVector{<:Real}) where {T<:Real} rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x) - Σ = sum(x) - if Σ == 0.0 - # Distribution behavior approaches categorical as Σα -> 0 - α = d.alpha - iΣα = inv(alpha0) - if isinf(iΣα) - # Dirichlet with ALL deeply subnormal parameters - α .*= floatmax(eltype(α)) - iΣα = inv(sum(α)) - end - x[rand(rng, Categorical(iΣα .* α))] = 1 - return x - end - - iΣ = inv(Σ) - if isinf(iΣ) - # Σ is deep subnormal - x .*= floatmax(eltype(x)) - iΣ = inv(sum(x)) - end - - lmul!(iΣ, x) # this returns x + _rand_handle_overflow!(rng, d, x) end ####################################### From d1baaf44bb8703f2537209afdc7a7b39fc117259 Mon Sep 17 00:00:00 2001 From: quildtide <42811940+quildtide@users.noreply.github.com> Date: Thu, 15 Aug 2024 22:51:15 -0400 Subject: [PATCH 3/9] Remove test that requires Xoshiro --- test/multivariate/dirichlet.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/multivariate/dirichlet.jl b/test/multivariate/dirichlet.jl index 7f6fd443b3..7797888866 100644 --- a/test/multivariate/dirichlet.jl +++ b/test/multivariate/dirichlet.jl @@ -179,7 +179,4 @@ end x = rand(d, 10^6) @test mean(x, dims = 2) ≈ μ atol=0.01 end - - # Should equal [0.625061099164708, 0.37493890083529186, 0] on Julia v1.11.0-rc1 - @test sum(rand(Xoshiro(123322), Dirichlet([4.5e-5, 4.5e-5, 8e-5]))) == 1 end \ No newline at end of file From 735324b78abf8d5d1c38b44ff0a46b97c82144c2 Mon Sep 17 00:00:00 2001 From: quildtide <42811940+quildtide@users.noreply.github.com> Date: Wed, 4 Sep 2024 00:41:34 -0400 Subject: [PATCH 4/9] Implement ExpGammaIPSampler Co-Authored-By: David 
Widmann --- src/samplers.jl | 1 + src/samplers/expgamma.jl | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 src/samplers/expgamma.jl diff --git a/src/samplers.jl b/src/samplers.jl index 794f2bff41..988fe5cd22 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -20,6 +20,7 @@ for fname in ["aliastable.jl", "poisson.jl", "exponential.jl", "gamma.jl", + "expgamma.jl", "multinomial.jl", "vonmises.jl", "vonmisesfisher.jl", diff --git a/src/samplers/expgamma.jl b/src/samplers/expgamma.jl new file mode 100644 index 0000000000..8d6ca3bbd8 --- /dev/null +++ b/src/samplers/expgamma.jl @@ -0,0 +1,22 @@ +# These are used to bypass subnormals when sampling from Gamma distributions with small shape parameters: they return log-space samples + +# Inverse Power sampler +# uses the x*u^(1/a) trick from Marsaglia and Tsang (2000) for when shape < 1 +struct ExpGammaIPSampler{S<:Sampleable{Univariate,Continuous},T<:Real} <: Sampleable{Univariate,Continuous} + s::S #sampler for Gamma(1+shape,scale) + nia::T #-1/shape +end + +ExpGammaIPSampler(d::Gamma) = ExpGammaIPSampler(d, GammaMTSampler) +function ExpGammaIPSampler(d::Gamma, ::Type{S}) where {S<:Sampleable} + shape_d = shape(d) + sampler = S(Gamma{partype(d)}(1 + shape_d, scale(d))) + return GammaIPSampler(sampler, -inv(shape_d)) +end + +function rand(rng::AbstractRNG, s::ExpGammaIPSampler) + x = log(rand(rng, s.s)) + e = randexp(rng) + return muladd(s.nia, e, x) +end + From 1ae621062525a8e9fbcdba84ecd3d1e4a12c4c23 Mon Sep 17 00:00:00 2001 From: quildtide <42811940+quildtide@users.noreply.github.com> Date: Wed, 4 Sep 2024 00:42:29 -0400 Subject: [PATCH 5/9] Implement ExpGammaSSSampler Co-Authored-By: chelate <42802644+chelate@users.noreply.github.com> --- src/samplers/expgamma.jl | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/samplers/expgamma.jl b/src/samplers/expgamma.jl index 8d6ca3bbd8..5d5b8d1637 100644 --- a/src/samplers/expgamma.jl +++ b/src/samplers/expgamma.jl @@ -20,3 +20,39 @@ function rand(rng::AbstractRNG, 
s::ExpGammaIPSampler) return muladd(s.nia, e, x) end + +# Small Shape sampler +# From Liu, C., Martin, R., and Syring, N. (2015) for when shape < 0.3 +struct ExpGammaSSSampler{T<:Real} <: Sampleable{Univariate,Continuous} + α::T + θ::T + λ::T + ω::T + ωω::T + η::T +end + +function ExpGammaSSSampler(d::Gamma) + α = shape(d) + ω = α / MathConstants.e / (1 - α) + return ExpGammaSSSampler(promote( + α, + scale(d), + inv(α) - 1, + ω, + inv(ω + 1) + )) +end + +function rand(rng::AbstractRNG, s::ExpGammaSSSampler) + while true + U = rand(rng) + z = (U <= s.ωω) ? -log(U / s.ωω) : log(rand(rng)) / s.λ + h = exp(-z - exp(-z / α)) + η = z >= 0 ? exp(-z) : s.ω * s.λ * exp(s.λ * z) + if h / η > rand(rng) + return z / α + end + end +end + From 06d81729db1889ba6f68fd37dbf96c0cb1ea5b3b Mon Sep 17 00:00:00 2001 From: quildtide <42811940+quildtide@users.noreply.github.com> Date: Wed, 4 Sep 2024 02:33:01 -0400 Subject: [PATCH 6/9] Implement improved Dirichlet rand --- src/multivariate/dirichlet.jl | 80 +++++++++++++++++++--------------- src/samplers/expgamma.jl | 46 +++++++++++++++---- test/multivariate/dirichlet.jl | 6 ++- 3 files changed, 87 insertions(+), 45 deletions(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index fc49bac3e2..04210f2cbc 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -154,49 +154,59 @@ end # sampling -function _rand_handle_overflow!( - rng::AbstractRNG, - d::Union{Dirichlet,DirichletCanon}, - x::AbstractVector{<:Real} - ) - Σ = sum(x) - if Σ == 0.0 - # Distribution behavior approaches categorical as Σα -> 0 - α = d.alpha - iΣα = inv(d.alpha0) - if isinf(iΣα) - # Dirichlet with ALL deeply subnormal parameters - α .*= floatmax(eltype(α)) - iΣα = inv(sum(α)) - end - x[rand(rng, Categorical(iΣα .* α))] = 1 - return x - end +function _rand!(rng::AbstractRNG, + d::Union{Dirichlet,DirichletCanon}, + x::AbstractVector{E}) where {E<:Real} - iΣ = inv(Σ) - if isinf(iΣ) - # Σ is deep subnormal - x .*= 
floatmax(eltype(x)) - iΣ = inv(sum(x)) - end + if any(a -> a > one(partype(d)), d.alpha) + for (i, αi) in zip(eachindex(x), d.alpha) + @inbounds x[i] = rand(rng, Gamma(αi)) + end - lmul!(iΣ, x) # this returns x -end + return lmul!(inv(sum(x)), x) + else + # Sample in log-space to lower underflow risk + for (i, αi) in zip(eachindex(x), d.alpha) + @inbounds x[i] = _logrand(rng, Gamma(αi)) + end -function _rand!(rng::AbstractRNG, - d::Union{Dirichlet,DirichletCanon}, - x::AbstractVector{<:Real}) - for (i, αi) in zip(eachindex(x), d.alpha) - @inbounds x[i] = rand(rng, Gamma(αi)) + if all(isinf, x) + # Final fallback, parameters likely deeply subnormal + # Distribution behavior approaches categorical as Σα -> 0 + p = copy(d.alpha) + p .*= floatmax(eltype(p)) # rescale to non-subnormal + x .= zero(E) + x[rand(rng, Categorical(inv(sum(p)) .* p))] = one(E) + return x + end + + return softmax!(x) end - _rand_handle_overflow!(rng, d, x) end function _rand!(rng::AbstractRNG, d::Dirichlet{T,<:FillArrays.AbstractFill{T}}, - x::AbstractVector{<:Real}) where {T<:Real} - rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x) - _rand_handle_overflow!(rng, d, x) + x::AbstractVector{E}) where {T<:Real, E<:Real} + + if FillArrays.getindex_value(d.alpha) > one(T) + rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x) + return lmul!(inv(sum(x)), x) + else + # Sample in log-space to lower underflow risk + _logrand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x) + + if all(isinf, x) + # Final fallback, parameters likely deeply subnormal + # Distribution behavior approaches categorical as Σα -> 0 + n = length(d.alpha) + p = Fill(inv(n), n) + x .= zero(E) + x[rand(rng, Categorical(p))] = one(E) + return x + end + + return softmax!(x) + end end ####################################### diff --git a/src/samplers/expgamma.jl b/src/samplers/expgamma.jl index 5d5b8d1637..f4674e3810 100644 --- a/src/samplers/expgamma.jl +++ b/src/samplers/expgamma.jl @@ -29,7 +29,6 @@ struct 
ExpGammaSSSampler{T<:Real} <: Sampleable{Univariate,Continuous} λ::T ω::T ωω::T - η::T end function ExpGammaSSSampler(d::Gamma) @@ -41,18 +40,47 @@ function ExpGammaSSSampler(d::Gamma) inv(α) - 1, ω, inv(ω + 1) - )) + )...) end -function rand(rng::AbstractRNG, s::ExpGammaSSSampler) +function rand(rng::AbstractRNG, s::ExpGammaSSSampler{T})::Float64 where T + flT = float(T) while true - U = rand(rng) - z = (U <= s.ωω) ? -log(U / s.ωω) : log(rand(rng)) / s.λ - h = exp(-z - exp(-z / α)) - η = z >= 0 ? exp(-z) : s.ω * s.λ * exp(s.λ * z) - if h / η > rand(rng) - return z / α + U = rand(rng, flT) + z = (U <= s.ωω) ? -log(U / s.ωω) : log(rand(rng, flT)) / s.λ + h = exp(-z - exp(-z / s.α)) + η = z >= zero(T) ? exp(-z) : s.ω * s.λ * exp(s.λ * z) + if h / η > rand(rng, flT) + return s.θ - z / s.α end end end + +function _logsampler(d::Gamma) + if shape(d) < 0.3 + return ExpGammaSSSampler(d) + else + return ExpGammaIPSampler(d) + end +end + +function _logrand(rng::AbstractRNG, d::Gamma) + if shape(d) < 0.3 + return rand(rng, ExpGammaSSSampler(d)) + else + return rand(rng, ExpGammaIPSampler(d)) + end +end + +function _logrand!(rng::AbstractRNG, d::Gamma, A::AbstractArray{<:Real}) + if shape(d) < 0.3 + @inbounds for i in eachindex(A) + A[i] = rand(rng, ExpGammaSSSampler(d)) + end + else + @inbounds for i in eachindex(A) + A[i] = rand(rng, ExpGammaIPSampler(d)) + end + end +end diff --git a/test/multivariate/dirichlet.jl b/test/multivariate/dirichlet.jl index 7797888866..6da49bf1a1 100644 --- a/test/multivariate/dirichlet.jl +++ b/test/multivariate/dirichlet.jl @@ -173,8 +173,12 @@ end end for (d, μ) in [ # Subnormal params cause mean(d) to error + + (Dirichlet([5e-310, 5e-310, 5e-310]), [1/3, 1/3, 1/3]), + (Dirichlet(FillArrays.Fill(5e-310, 3)), [1/3, 1/3, 1/3]), (Dirichlet([5e-321, 1e-321, 4e-321]), [.5, .1, .4]), - (Dirichlet([1e-321, 2e-321, 3e-321, 4e-321]), [.1, .2, .3, .4]) + (Dirichlet([1e-321, 2e-321, 3e-321, 4e-321]), [.1, .2, .3, .4]), + 
(Dirichlet(FillArrays.Fill(1e-321, 4)), [.25, .25, .25, .25]) ] x = rand(d, 10^6) @test mean(x, dims = 2) ≈ μ atol=0.01 From f50e8c8ffe39a84a1305770682e88b91eac75fd5 Mon Sep 17 00:00:00 2001 From: quildtide <42811940+quildtide@users.noreply.github.com> Date: Wed, 4 Sep 2024 02:44:14 -0400 Subject: [PATCH 7/9] Apply #1885 type change to ExpGammaIPSampler --- src/samplers/expgamma.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/samplers/expgamma.jl b/src/samplers/expgamma.jl index f4674e3810..3fefbb3a41 100644 --- a/src/samplers/expgamma.jl +++ b/src/samplers/expgamma.jl @@ -16,7 +16,7 @@ end function rand(rng::AbstractRNG, s::ExpGammaIPSampler) x = log(rand(rng, s.s)) - e = randexp(rng) + e = randexp(rng, typeof(x)) return muladd(s.nia, e, x) end From 0bd5b5cf04a6dba0b4b9ea22b98051749cc2a7d7 Mon Sep 17 00:00:00 2001 From: quildtide <42811940+quildtide@users.noreply.github.com> Date: Wed, 4 Sep 2024 03:50:09 -0400 Subject: [PATCH 8/9] Lower non-log threshold --- src/multivariate/dirichlet.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 04210f2cbc..9750d40478 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -158,7 +158,7 @@ function _rand!(rng::AbstractRNG, d::Union{Dirichlet,DirichletCanon}, x::AbstractVector{E}) where {E<:Real} - if any(a -> a > one(partype(d)), d.alpha) + if any(a -> a >= .5, d.alpha) for (i, αi) in zip(eachindex(x), d.alpha) @inbounds x[i] = rand(rng, Gamma(αi)) end @@ -188,7 +188,7 @@ function _rand!(rng::AbstractRNG, d::Dirichlet{T,<:FillArrays.AbstractFill{T}}, x::AbstractVector{E}) where {T<:Real, E<:Real} - if FillArrays.getindex_value(d.alpha) > one(T) + if FillArrays.getindex_value(d.alpha) >= 0.5 rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x) return lmul!(inv(sum(x)), x) else From fbc6763beb6d03198b9df0a26d751ead1a24ab7c Mon Sep 17 00:00:00 2001 From: quildtide 
<42811940+quildtide@users.noreply.github.com> Date: Wed, 26 Mar 2025 00:42:46 -0400 Subject: [PATCH 9/9] Note origin of thresholds --- src/multivariate/dirichlet.jl | 4 ++++ src/samplers/expgamma.jl | 3 +++ 2 files changed, 7 insertions(+) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 9750d40478..7267b96aa8 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -159,6 +159,8 @@ function _rand!(rng::AbstractRNG, x::AbstractVector{E}) where {E<:Real} if any(a -> a >= .5, d.alpha) + # 0.5 is a placeholder; optimal value unknown + # 1 is known to be too high. for (i, αi) in zip(eachindex(x), d.alpha) @inbounds x[i] = rand(rng, Gamma(αi)) end @@ -189,6 +191,8 @@ function _rand!(rng::AbstractRNG, x::AbstractVector{E}) where {T<:Real, E<:Real} if FillArrays.getindex_value(d.alpha) >= 0.5 + # 0.5 is a placeholder; optimal value unknown + # 1 is known to be too high. rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x) return lmul!(inv(sum(x)), x) else diff --git a/src/samplers/expgamma.jl b/src/samplers/expgamma.jl index 3fefbb3a41..788ee6d851 100644 --- a/src/samplers/expgamma.jl +++ b/src/samplers/expgamma.jl @@ -59,8 +59,11 @@ end function _logsampler(d::Gamma) if shape(d) < 0.3 + # Liu, Martin, and Syring recommend 0.3 as a cutoff to switch + # to Kundu-Gupta, but we have not implemented Kundu-Gupta yet. return ExpGammaSSSampler(d) else + # TODO: Kundu-Gupta algo. #3 for performance reasons? return ExpGammaIPSampler(d) end end