diff --git a/src/samplers/vonmisesfisher.jl b/src/samplers/vonmisesfisher.jl index fd3eb2df08..31e6643ec1 100644 --- a/src/samplers/vonmisesfisher.jl +++ b/src/samplers/vonmisesfisher.jl @@ -1,34 +1,50 @@ # Sampler for von Mises-Fisher +# Ref https://doi.org/10.18637/jss.v058.i10 +# Ref https://hal.science/hal-04004568v3 struct VonMisesFisherSampler <: Sampleable{Multivariate,Continuous} p::Int # the dimension κ::Float64 b::Float64 x0::Float64 c::Float64 + τ::Float64 v::Vector{Float64} end function VonMisesFisherSampler(μ::Vector{Float64}, κ::Float64) + # Step 1: Calculate b, x₀, and c p = length(μ) - b = _vmf_bval(p, κ) + b = (p - 1) / (2.0κ + sqrt(4 * abs2(κ) + abs2(p - 1))) x0 = (1.0 - b) / (1.0 + b) c = κ * x0 + (p - 1) * log1p(-abs2(x0)) - v = _vmf_householder_vec(μ) - VonMisesFisherSampler(p, κ, b, x0, c, v) + + # Compute Householder transformation: + # `LinearAlgebra.reflector!` computes a Householder transformation H such that + # H μ = -copysign(|μ|₂, μ[1]) e₁ + # μ is a unit vector, and hence this implies that + # H e₁ = μ if μ[1] < 0 and H (-e₁) = μ otherwise + # Since `reflector!` overwrites `v[1]` with `-copysign(1, μ[1])`, the sign of `μ[1]` can be extracted from `v[1]` during sampling + v = similar(μ) + copyto!(v, μ) + τ = LinearAlgebra.reflector!(v) + + return VonMisesFisherSampler(p, κ, b, x0, c, τ, v) end Base.length(s::VonMisesFisherSampler) = length(s.v) -@inline function _vmf_rot!(v::AbstractVector, x::AbstractVector) - # rotate - scale = 2.0 * (v' * x) - @. 
x -= (scale * v) - return x -end +function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector{<:Real}) + # TODO: Generalize to more general indices + Base.require_one_based_indexing(x) + # Sample angle `w` assuming mean direction `(1, 0, ..., 0)` + w = _vmf_angle(rng, spl) + + # Transform to sample for mean direction `(-copysign(1.0, μ[1]), 0, ..., 0)` + v = spl.v + w = flipsign(w, v[1]) -function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector) - w = _vmf_genw(rng, spl) + # Generate sample assuming mean direction `(-copysign(1.0, μ[1]), 0, ..., 0)` p = spl.p x[1] = w s = 0.0 @@ -43,60 +59,75 @@ function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector) x[i] *= r end - return _vmf_rot!(spl.v, x) + # Apply Householder transformation to mean direction `μ` + return LinearAlgebra.reflectorApply!(v, spl.τ, x) end ### Core computation -_vmf_bval(p::Int, κ::Real) = (p - 1) / (2.0κ + sqrt(4 * abs2(κ) + abs2(p - 1))) - -function _vmf_genw3(rng::AbstractRNG, p, b, x0, c, κ) - ξ = rand(rng) - w = 1.0 + (log(ξ + (1.0 - ξ)*exp(-2κ))/κ) - return w::Float64 -end - -function _vmf_genwp(rng::AbstractRNG, p, b, x0, c, κ) - r = (p - 1) / 2.0 - betad = Beta(r, r) - z = rand(rng, betad) - w = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z) - while κ * w + (p - 1) * log(1 - x0 * w) - c < log(rand(rng)) - z = rand(rng, betad) - w = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z) - end - return w::Float64 -end +# Step 2: Sample angle W +function _vmf_angle(rng::AbstractRNG, spl::VonMisesFisherSampler) + p = spl.p + κ = spl.κ -# generate the W value -- the key step in simulating vMF -# -# following movMF's document for the p != 3 case -# and Wenzel Jakob's document for the p == 3 case -function _vmf_genw(rng::AbstractRNG, p, b, x0, c, κ) if p == 3 - return _vmf_genw3(rng, p, b, x0, c, κ) + _vmf_angle3(rng, κ) else - return _vmf_genwp(rng, p, b, x0, c, κ) + # General case: Rejection sampling + # Ref 
https://doi.org/10.18637/jss.v058.i10 + b = spl.b + c = spl.c + p = spl.p + κ = spl.κ + x0 = spl.x0 + pm1 = p - 1 + + if p == 2 + # In this case the distribution reduces to the von Mises distribution on the circle + # We exploit the fact that `Beta(1/2, 1/2) = Arcsine(0, 1)` + dist = Arcsine(zero(b), one(b)) + while true + z = rand(rng, dist) + w = (1 - (1 + b) * z) / (1 - (1 - b) * z) + if κ * w + pm1 * log1p(- x0 * w) >= c - randexp(rng) + return w::Float64 + end + end + else + # We sample from a `Beta((p - 1)/2, (p - 1)/2)` distribution, possibly repeatedly + # Therefore we construct a sampler + # To avoid the type instability of `sampler(Beta(...))` and `sampler(Gamma(...))` + # we directly construct the Gamma sampler for Gamma((p - 1)/2, 1) + # Since p == 2 and p == 3 are handled separately, (p - 1)/2 > 1 here (assuming p >= 4), so we construct a `GammaMTSampler` + r = pm1 / 2 + gammasampler = GammaMTSampler(Gamma{typeof(r)}(r, one(r))) + while true + # w is supposed to be generated as + # z ~ Beta((p - 1)/ 2, (p - 1)/2) + # w = (1 - (1 + b) * z) / (1 - (1 - b) * z) + # We sample z as + # z1 ~ Gamma((p - 1) / 2, 1) + # z2 ~ Gamma((p - 1) / 2, 1) + # z = z1 / (z1 + z2) + # and rewrite the expression for w + # Cf. 
case p == 2 above + z1 = rand(rng, gammasampler) + z2 = rand(rng, gammasampler) + b_z1 = b * z1 + w = (z2 - b_z1) / (z2 + b_z1) + if κ * w + pm1 * log1p(- x0 * w) >= c - randexp(rng) + return w::Float64 + end + end + end end end - -_vmf_genw(rng::AbstractRNG, s::VonMisesFisherSampler) = - _vmf_genw(rng, s.p, s.b, s.x0, s.c, s.κ) - -function _vmf_householder_vec(μ::Vector{Float64}) - # assuming μ is a unit-vector (which it should be) - # can compute v in a single pass over μ - - p = length(μ) - v = similar(μ) - v[1] = μ[1] - 1.0 - s = sqrt(-2*v[1]) - v[1] /= s - - @inbounds for i in 2:p - v[i] = μ[i] / s - end - - return v +# Special case: 2-sphere +@inline function _vmf_angle3(rng::AbstractRNG, κ::Real) + # In this case, we can directly sample the angle + # Ref https://www.mitsuba-renderer.org/~wenzel/files/vmf.pdf + ξ = rand(rng) + w = 1.0 + (log(ξ + (1.0 - ξ)*exp(-2κ))/κ) + return w::Float64 end diff --git a/src/univariate/continuous/gamma.jl b/src/univariate/continuous/gamma.jl index 866255fb7d..8ba207d2c7 100644 --- a/src/univariate/continuous/gamma.jl +++ b/src/univariate/continuous/gamma.jl @@ -105,7 +105,6 @@ function rand(rng::AbstractRNG, d::Gamma) # TODO: shape(d) = 0.5 : use scaled chisq return rand(rng, GammaIPSampler(d)) elseif shape(d) == 1.0 - θ = return rand(rng, Exponential{partype(d)}(scale(d))) else return rand(rng, GammaMTSampler(d)) diff --git a/test/multivariate/vonmisesfisher.jl b/test/multivariate/vonmisesfisher.jl index cc45f41ed5..c4f8b2859d 100644 --- a/test/multivariate/vonmisesfisher.jl +++ b/test/multivariate/vonmisesfisher.jl @@ -22,28 +22,7 @@ function gen_vmf_tdata(n::Int, p::Int, return X end -function test_vmf_rot(p::Int, rng::Union{AbstractRNG, Missing} = missing) - if ismissing(rng) - μ = randn(p) - x = randn(p) - else - μ = randn(rng, p) - x = randn(rng, p) - end - κ = norm(μ) - μ = μ ./ κ - - s = Distributions.VonMisesFisherSampler(μ, κ) - v = μ - vcat(1, zeros(p-1)) - H = I - 2*v*v'/(v'*v) - - @test 
Distributions._vmf_rot!(s.v, copy(x)) ≈ (H*x) - -end - - - -function test_genw3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missing) +function test_angle3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missing) p = 3 if ismissing(rng) @@ -53,21 +32,20 @@ function test_genw3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missin end μ = μ ./ norm(μ) - s = Distributions.VonMisesFisherSampler(μ, float(κ)) + spl = Distributions.VonMisesFisherSampler(μ, float(κ)) + angle3_res = [Distributions._vmf_angle3(rng, spl.κ) for _ in 1:ns] + angle_res = [Distributions._vmf_angle(rng, spl) for _ in 1:ns] - genw3_res = [Distributions._vmf_genw3(rng, s.p, s.b, s.x0, s.c, s.κ) for _ in 1:ns] - genwp_res = [Distributions._vmf_genwp(rng, s.p, s.b, s.x0, s.c, s.κ) for _ in 1:ns] - - @test isapprox(mean(genw3_res), mean(genwp_res), atol=0.01) - @test isapprox(std(genw3_res), std(genwp_res), atol=0.01/κ) + @test mean(angle3_res) ≈ mean(angle_res) rtol=5e-2 + @test std(angle3_res) ≈ std(angle_res) rtol=1e-2 # test mean and stdev against analytical formulas coth_κ = coth(κ) mean_w = coth_κ - 1/κ var_w = 1 - coth_κ^2 + 1/κ^2 - @test isapprox(mean(genw3_res), mean_w, atol=0.01) - @test isapprox(std(genw3_res), sqrt(var_w), atol=0.01/κ) + @test mean(angle3_res) ≈ mean_w rtol=5e-2 + @test std(angle3_res) ≈ sqrt(var_w) rtol=1e-2 end @@ -173,12 +151,21 @@ ns = 10^6 (2, 2), (2, 1000)] # test with large κ test_vonmisesfisher(p, κ, n, ns, rng) - test_vmf_rot(p, rng) end if !ismissing(rng) @testset "Testing genw with $key at (3, $κ)" for κ in [0.1, 0.5, 1.0, 2.0, 5.0] - test_genw3(κ, ns, rng) + test_angle3(κ, ns, rng) end end end + +# issue #1423 +@testset "Special case: No rotation" begin + for n in 2:10 + d = VonMisesFisher(vcat(1, zeros(n - 1)), 1.0) + @test sum(abs2, rand(d)) ≈ 1 + d_est = fit_mle(VonMisesFisher, rand(d, 100_000)) + @test meandir(d_est) ≈ meandir(d) rtol=5e-2 + end +end