From d30e355a77117d07a82456826ae78719b35f1cbf Mon Sep 17 00:00:00 2001 From: Mikhail Mikhasenko Date: Tue, 17 Jun 2025 22:47:31 +0200 Subject: [PATCH 01/13] Add dispatch for drawing multiple samples from UnivariateMixtureModel --- src/mixtures/mixturemodel.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/mixtures/mixturemodel.jl b/src/mixtures/mixturemodel.jl index 7f3664290..c0f2c70cb 100644 --- a/src/mixtures/mixturemodel.jl +++ b/src/mixtures/mixturemodel.jl @@ -477,6 +477,22 @@ rand(rng::AbstractRNG, s::MixtureSampler{Univariate}) = rand(rng::AbstractRNG, d::MixtureModel{Univariate}) = rand(rng, component(d, rand(rng, d.prior))) +function rand(rng::AbstractRNG, d::MixtureModel{Univariate}, n::Int) + counts = rand(rng, Multinomial(n, probs(d.prior))) + x = Vector{eltype(d)}(undef, n) + offset = 0 + for i in eachindex(counts) + ni = counts[i] + if ni > 0 + c = component(d, i) + v = view(x, (offset+1):(offset+ni)) + v .= rand(rng, c, ni) + offset += ni + end + end + return shuffle!(rng, x) +end + # multivariate mixture sampler for a vector _rand!(rng::AbstractRNG, s::MixtureSampler{Multivariate}, x::AbstractVector{<:Real}) = @inbounds rand!(rng, s.csamplers[rand(rng, s.psampler)], x) From 4aef8904315ec5854cd876fa6ff2f6d27ea7c31e Mon Sep 17 00:00:00 2001 From: Mikhail Mikhasenko Date: Tue, 17 Jun 2025 23:04:36 +0200 Subject: [PATCH 02/13] first implementation of rand-n for truncated --- src/truncate.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/truncate.jl b/src/truncate.jl index 48d62b015..03992eb89 100644 --- a/src/truncate.jl +++ b/src/truncate.jl @@ -233,6 +233,25 @@ function rand(rng::AbstractRNG, d::Truncated) end end + +function rand(rng::AbstractRNG, d::Truncated, n::Int) + d0 = d.untruncated + tp = d.tp + lower = d.lower + upper = d.upper + # Correct for rejected samples + n_corrected = round(Int, n / tp) + n_gen = n_corrected + 3 * round(Int, sqrt(n_corrected)) + # + sample = rand(rng, d0, n_gen) + 
filter!(sample) do r + _in_closed_interval(r, lower, upper) + end + length(sample) > n_corrected && return sample[1:n_corrected] + # If we didn't get enough samples, generate more + return rand(rng, d, n) # try again +end + ## show function show(io::IO, d::Truncated) From 3fc31cef8ae7621f7e888a4c045f97188610311f Mon Sep 17 00:00:00 2001 From: Mikhail Mikhasenko Date: Tue, 17 Jun 2025 23:38:07 +0200 Subject: [PATCH 03/13] no-recursion in implementation of rand-n for truncated --- src/truncate.jl | 42 ++++++++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/src/truncate.jl b/src/truncate.jl index 03992eb89..6620b1779 100644 --- a/src/truncate.jl +++ b/src/truncate.jl @@ -234,22 +234,44 @@ function rand(rng::AbstractRNG, d::Truncated) end +""" + rand(rng::AbstractRNG, d::Truncated, n::Int) + +Generate `n` random samples from a truncated distribution. + +The implementation uses rejection sampling. It draws samples from the untruncated distribution in batches and only keeps the samples that fall within the truncated range. +The size of the batches is adaptively estimated to reduce the number of iterations. + +!!! warning + This method can be inefficient if the probability mass of the truncated region is very small. +""" function rand(rng::AbstractRNG, d::Truncated, n::Int) + n == 0 && return eltype(d)[] + # d0 = d.untruncated tp = d.tp lower = d.lower upper = d.upper - # Correct for rejected samples - n_corrected = round(Int, n / tp) - n_gen = n_corrected + 3 * round(Int, sqrt(n_corrected)) - # - sample = rand(rng, d0, n_gen) - filter!(sample) do r - _in_closed_interval(r, lower, upper) + # Preallocate samples array + samples = Vector{eltype(d)}(undef, n) + n_collected = 0 + while n_collected < n + n_remaining = n - n_collected + # Estimate number of samples to draw from the untruncated distribution. + # We draw more to reduce the chance of needing more rounds. 
+ n_expected = n_remaining / tp + δn_expected = sqrt(n_remaining * tp * (1 - tp)) # standard deviation of the expected number + n_batch = ceil(Int, n_expected + 3δn_expected) + samples_d0 = rand(rng, d0, n_batch) + for s in samples_d0 + if _in_closed_interval(s, lower, upper) + n_collected += 1 + samples[n_collected] = s + n_collected == n && break + end + end end - length(sample) > n_corrected && return sample[1:n_corrected] - # If we didn't get enough samples, generate more - return rand(rng, d, n) # try again + return samples end ## show From 5f406ab8c93aeb7e7b0545818309448bb2618d42 Mon Sep 17 00:00:00 2001 From: Mikhail Mikhasenko Date: Tue, 17 Jun 2025 23:46:19 +0200 Subject: [PATCH 04/13] docstring for existing rand --- src/truncate.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/truncate.jl b/src/truncate.jl index 6620b1779..4b361e25f 100644 --- a/src/truncate.jl +++ b/src/truncate.jl @@ -213,6 +213,17 @@ end ## random number generation + +""" + rand(rng::AbstractRNG, d::Truncated) + +Generate a single random sample from a truncated distribution. + +The sampling strategy depends on the probability mass of the truncated region (`tp`): +- If `tp > 0.25`, rejection sampling is used. This is efficient when the truncated region covers a large portion of the original distribution. +- If `sqrt(eps) < tp <= 0.25`, inverse transform sampling is used. This is more efficient for smaller truncated regions. +- If `tp` is very small (`<= sqrt(eps)`), a numerically stable version of inverse transform sampling is used which performs calculations in log-space to maintain precision. 
+""" function rand(rng::AbstractRNG, d::Truncated) d0 = d.untruncated tp = d.tp From c10ef3ef49f6f635dfd0c320f0f2e3f32a559010 Mon Sep 17 00:00:00 2001 From: Mikhail Mikhasenko Date: Tue, 17 Jun 2025 23:46:42 +0200 Subject: [PATCH 05/13] minor update to rand-n truncated docstring --- src/truncate.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/truncate.jl b/src/truncate.jl index 4b361e25f..423607cd0 100644 --- a/src/truncate.jl +++ b/src/truncate.jl @@ -250,11 +250,12 @@ end Generate `n` random samples from a truncated distribution. -The implementation uses rejection sampling. It draws samples from the untruncated distribution in batches and only keeps the samples that fall within the truncated range. -The size of the batches is adaptively estimated to reduce the number of iterations. +The implementation samples the untruncated distribution, `d0` with `rand(rng, d0, n)` in batches and only keeps the samples that fall within the truncated range. The size of the batches is adaptively estimated to reduce the number of iterations. + +See [rand(rng::AbstractRNG, d::Truncated)](@ref) that handles the case of small mass of the truncated region. !!! warning - This method can be inefficient if the probability mass of the truncated region is very small. + This method can be inefficient if the probability mass of the truncated region is very small. 
""" function rand(rng::AbstractRNG, d::Truncated, n::Int) n == 0 && return eltype(d)[] From df17b509ab9090ff418d236e5615e114f26ec35c Mon Sep 17 00:00:00 2001 From: Misha Mikhasenko Date: Wed, 18 Jun 2025 15:41:45 +0200 Subject: [PATCH 06/13] Update src/mixtures/mixturemodel.jl Co-authored-by: David Widmann --- src/mixtures/mixturemodel.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mixtures/mixturemodel.jl b/src/mixtures/mixturemodel.jl index c0f2c70cb..9165f45fe 100644 --- a/src/mixtures/mixturemodel.jl +++ b/src/mixtures/mixturemodel.jl @@ -485,9 +485,9 @@ function rand(rng::AbstractRNG, d::MixtureModel{Univariate}, n::Int) ni = counts[i] if ni > 0 c = component(d, i) - v = view(x, (offset+1):(offset+ni)) - v .= rand(rng, c, ni) - offset += ni + last_offset = offset + ni - 1 + rand!(rng, c, @view(x[(begin + offset):(begin + last_offset)])) + offset = last_offset + 1 end end return shuffle!(rng, x) From 7501a5aa67497262276e6505ae131b54f99ff996 Mon Sep 17 00:00:00 2001 From: Mikhail Mikhasenko Date: Wed, 18 Jun 2025 16:00:59 +0200 Subject: [PATCH 07/13] eltype to partype --- src/mixtures/mixturemodel.jl | 6 +++--- src/truncate.jl | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/mixtures/mixturemodel.jl b/src/mixtures/mixturemodel.jl index 9165f45fe..284792319 100644 --- a/src/mixtures/mixturemodel.jl +++ b/src/mixtures/mixturemodel.jl @@ -479,14 +479,14 @@ rand(rng::AbstractRNG, d::MixtureModel{Univariate}) = function rand(rng::AbstractRNG, d::MixtureModel{Univariate}, n::Int) counts = rand(rng, Multinomial(n, probs(d.prior))) - x = Vector{eltype(d)}(undef, n) + x = Vector{partype(d)}(undef, n) offset = 0 for i in eachindex(counts) ni = counts[i] if ni > 0 c = component(d, i) - last_offset = offset + ni - 1 - rand!(rng, c, @view(x[(begin + offset):(begin + last_offset)])) + last_offset = offset + ni - 1 + rand!(rng, c, @view(x[(begin+offset):(begin+last_offset)])) offset = last_offset + 1 end end diff --git 
a/src/truncate.jl b/src/truncate.jl index 423607cd0..b2c25c65e 100644 --- a/src/truncate.jl +++ b/src/truncate.jl @@ -258,14 +258,14 @@ See [rand(rng::AbstractRNG, d::Truncated)](@ref) that handles the case of small This method can be inefficient if the probability mass of the truncated region is very small. """ function rand(rng::AbstractRNG, d::Truncated, n::Int) - n == 0 && return eltype(d)[] + n == 0 && return partype(d)[] # d0 = d.untruncated tp = d.tp lower = d.lower upper = d.upper # Preallocate samples array - samples = Vector{eltype(d)}(undef, n) + samples = Vector{partype(d)}(undef, n) n_collected = 0 while n_collected < n n_remaining = n - n_collected From 832519fb05408633e2cd8d059d6aea83a4a8bac7 Mon Sep 17 00:00:00 2001 From: Misha Mikhasenko Date: Wed, 18 Jun 2025 18:21:18 +0200 Subject: [PATCH 08/13] remove rand truncated docstring Co-authored-by: David Widmann --- src/truncate.jl | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/truncate.jl b/src/truncate.jl index b2c25c65e..a803474da 100644 --- a/src/truncate.jl +++ b/src/truncate.jl @@ -213,17 +213,6 @@ end ## random number generation - -""" - rand(rng::AbstractRNG, d::Truncated) - -Generate a single random sample from a truncated distribution. - -The sampling strategy depends on the probability mass of the truncated region (`tp`): -- If `tp > 0.25`, rejection sampling is used. This is efficient when the truncated region covers a large portion of the original distribution. -- If `sqrt(eps) < tp <= 0.25`, inverse transform sampling is used. This is more efficient for smaller truncated regions. -- If `tp` is very small (`<= sqrt(eps)`), a numerically stable version of inverse transform sampling is used which performs calculations in log-space to maintain precision. 
-""" function rand(rng::AbstractRNG, d::Truncated) d0 = d.untruncated tp = d.tp From ee17d46ce3b2740b1de18cd60526f7a2f6ad7cb4 Mon Sep 17 00:00:00 2001 From: Misha Mikhasenko Date: Wed, 18 Jun 2025 18:21:49 +0200 Subject: [PATCH 09/13] remove rand-n truncated docstring Co-authored-by: David Widmann --- src/truncate.jl | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/truncate.jl b/src/truncate.jl index a803474da..26e5365b4 100644 --- a/src/truncate.jl +++ b/src/truncate.jl @@ -233,19 +233,6 @@ function rand(rng::AbstractRNG, d::Truncated) end end - -""" - rand(rng::AbstractRNG, d::Truncated, n::Int) - -Generate `n` random samples from a truncated distribution. - -The implementation samples the untruncated distribution, `d0` with `rand(rng, d0, n)` in batches and only keeps the samples that fall within the truncated range. The size of the batches is adaptively estimated to reduce the number of iterations. - -See [rand(rng::AbstractRNG, d::Truncated)](@ref) that handles the case of small mass of the truncated region. - -!!! warning - This method can be inefficient if the probability mass of the truncated region is very small. 
-""" function rand(rng::AbstractRNG, d::Truncated, n::Int) n == 0 && return partype(d)[] # From 8a9c15af9bcba8dd1f557c3119783156b8a45176 Mon Sep 17 00:00:00 2001 From: Mikhail Mikhasenko Date: Wed, 18 Jun 2025 18:36:08 +0200 Subject: [PATCH 10/13] in-place rand(rng::AbstractRNG, d::Truncated, n::Int) --- src/truncate.jl | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/truncate.jl b/src/truncate.jl index 26e5365b4..e3783fe04 100644 --- a/src/truncate.jl +++ b/src/truncate.jl @@ -243,15 +243,21 @@ function rand(rng::AbstractRNG, d::Truncated, n::Int) # Preallocate samples array samples = Vector{partype(d)}(undef, n) n_collected = 0 + # Preallocate a buffer for batch sampling (size will be adjusted as needed) + max_batch = 0 + batch_buffer = Vector{partype(d)}() while n_collected < n n_remaining = n - n_collected - # Estimate number of samples to draw from the untruncated distribution. - # We draw more to reduce the chance of needing more rounds. n_expected = n_remaining / tp - δn_expected = sqrt(n_remaining * tp * (1 - tp)) # standard deviation of the expected number + δn_expected = sqrt(n_remaining * tp * (1 - tp)) n_batch = ceil(Int, n_expected + 3δn_expected) - samples_d0 = rand(rng, d0, n_batch) - for s in samples_d0 + if n_batch > max_batch + resize!(batch_buffer, n_batch) + max_batch = n_batch + end + rand!(rng, d0, batch_buffer) + for i in 1:n_batch + s = batch_buffer[i] if _in_closed_interval(s, lower, upper) n_collected += 1 samples[n_collected] = s From 4fed7442546e94d1715daf3ed4e168c2073658ed Mon Sep 17 00:00:00 2001 From: Mikhail Mikhasenko Date: Wed, 18 Jun 2025 19:10:46 +0200 Subject: [PATCH 11/13] fix for small truncated distributions with tiny remaining probability - fallback to rand(d) --- src/truncate.jl | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/truncate.jl b/src/truncate.jl index e3783fe04..b9610720c 100644 --- a/src/truncate.jl +++ b/src/truncate.jl @@ -240,17 
+240,25 @@ function rand(rng::AbstractRNG, d::Truncated, n::Int) tp = d.tp lower = d.lower upper = d.upper - # Preallocate samples array samples = Vector{partype(d)}(undef, n) n_collected = 0 - # Preallocate a buffer for batch sampling (size will be adjusted as needed) max_batch = 0 batch_buffer = Vector{partype(d)}() + # If tp is extremely small, fall back to scalar sampling + threshold = 1e7 # maximum batch size allowed while n_collected < n n_remaining = n - n_collected n_expected = n_remaining / tp δn_expected = sqrt(n_remaining * tp * (1 - tp)) - n_batch = ceil(Int, n_expected + 3δn_expected) + n_batch_f = n_expected + 3δn_expected + if !isfinite(n_batch_f) || n_batch_f > threshold + # Fallback: use scalar method for remaining samples + for i in 1:n_remaining + samples[n_collected+i] = rand(rng, d) + end + break + end + n_batch = ceil(Int, n_batch_f) if n_batch > max_batch resize!(batch_buffer, n_batch) max_batch = n_batch From 0d303ba8c45cd75c17129da546fb9f556be2ea2b Mon Sep 17 00:00:00 2001 From: Mikhail Mikhasenko Date: Mon, 23 Jun 2025 16:31:05 +0200 Subject: [PATCH 12/13] switch to three regimes --- src/truncate.jl | 70 ++++++++++++++++++++++++++++--------------------- 1 file changed, 40 insertions(+), 30 deletions(-) diff --git a/src/truncate.jl b/src/truncate.jl index b9610720c..f30312f32 100644 --- a/src/truncate.jl +++ b/src/truncate.jl @@ -235,45 +235,55 @@ end function rand(rng::AbstractRNG, d::Truncated, n::Int) n == 0 && return partype(d)[] - # + d0 = d.untruncated tp = d.tp lower = d.lower upper = d.upper - samples = Vector{partype(d)}(undef, n) - n_collected = 0 - max_batch = 0 - batch_buffer = Vector{partype(d)}() - # If tp is extremely small, fall back to scalar sampling - threshold = 1e7 # maximum batch size allowed - while n_collected < n - n_remaining = n - n_collected - n_expected = n_remaining / tp - δn_expected = sqrt(n_remaining * tp * (1 - tp)) - n_batch_f = n_expected + 3δn_expected - if !isfinite(n_batch_f) || n_batch_f > threshold - 
# Fallback: use scalar method for remaining samples - for i in 1:n_remaining - samples[n_collected+i] = rand(rng, d) + + # Use the same three regimes as the scalar version + if tp > 0.25 + # Regime 1: Rejection sampling with batch optimization + samples = Vector{partype(d)}(undef, n) + n_collected = 0 + max_batch = 0 + batch_buffer = Vector{partype(d)}() + while n_collected < n + n_remaining = n - n_collected + n_expected = n_remaining / tp + δn_expected = sqrt(n_remaining * tp * (1 - tp)) + n_batch_f = n_expected + 3δn_expected + n_batch = ceil(Int, n_batch_f) + if n_batch > max_batch + resize!(batch_buffer, n_batch) + max_batch = n_batch + end + rand!(rng, d0, batch_buffer) + for i in 1:n_batch + s = batch_buffer[i] + if _in_closed_interval(s, lower, upper) + n_collected += 1 + samples[n_collected] = s + n_collected == n && break + end end - break end - n_batch = ceil(Int, n_batch_f) - if n_batch > max_batch - resize!(batch_buffer, n_batch) - max_batch = n_batch + return samples + elseif tp > sqrt(eps(typeof(float(tp)))) + # Regime 2: Quantile-based sampling + samples = Vector{partype(d)}(undef, n) + for i in 1:n + samples[i] = quantile(d0, d.lcdf + rand(rng) * d.tp) end - rand!(rng, d0, batch_buffer) - for i in 1:n_batch - s = batch_buffer[i] - if _in_closed_interval(s, lower, upper) - n_collected += 1 - samples[n_collected] = s - n_collected == n && break - end + return samples + else + # Regime 3: Log-space computation + samples = Vector{partype(d)}(undef, n) + for i in 1:n + samples[i] = invlogcdf(d0, logaddexp(d.loglcdf, d.logtp - randexp(rng))) end + return samples end - return samples end ## show From 4534327b5b8a2875157e31923c746eba989c7b2b Mon Sep 17 00:00:00 2001 From: Mikhail Mikhasenko Date: Mon, 23 Jun 2025 21:29:12 +0200 Subject: [PATCH 13/13] rand + resize strategy --- src/mixtures/mixturemodel.jl | 28 ++++++++++++++++++++-------- src/truncate.jl | 15 ++++++++++----- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git 
a/src/mixtures/mixturemodel.jl b/src/mixtures/mixturemodel.jl index 284792319..38e9d59e5 100644 --- a/src/mixtures/mixturemodel.jl +++ b/src/mixtures/mixturemodel.jl @@ -479,15 +479,27 @@ rand(rng::AbstractRNG, d::MixtureModel{Univariate}) = function rand(rng::AbstractRNG, d::MixtureModel{Univariate}, n::Int) counts = rand(rng, Multinomial(n, probs(d.prior))) - x = Vector{partype(d)}(undef, n) - offset = 0 + + # Find the component with the maximum count to minimize resizing + max_count_idx = argmax(counts) + max_count = counts[max_count_idx] + + # Sample from the component with maximum count first and use it directly + x = rand(rng, component(d, max_count_idx), max_count) + + # Resize to the full size and continue with other components + resize!(x, n) + offset = max_count + for i in eachindex(counts) - ni = counts[i] - if ni > 0 - c = component(d, i) - last_offset = offset + ni - 1 - rand!(rng, c, @view(x[(begin+offset):(begin+last_offset)])) - offset = last_offset + 1 + if i != max_count_idx + ni = counts[i] + if ni > 0 + c = component(d, i) + last_offset = offset + ni - 1 + rand!(rng, c, @view(x[(begin+offset):(begin+last_offset)])) + offset = last_offset + 1 + end end end return shuffle!(rng, x) diff --git a/src/truncate.jl b/src/truncate.jl index f30312f32..f0a31ddc6 100644 --- a/src/truncate.jl +++ b/src/truncate.jl @@ -234,7 +234,7 @@ function rand(rng::AbstractRNG, d::Truncated) end function rand(rng::AbstractRNG, d::Truncated, n::Int) - n == 0 && return partype(d)[] + n == 0 && return rand(d.untruncated, 0) d0 = d.untruncated tp = d.tp @@ -244,10 +244,11 @@ function rand(rng::AbstractRNG, d::Truncated, n::Int) # Use the same three regimes as the scalar version if tp > 0.25 # Regime 1: Rejection sampling with batch optimization - samples = Vector{partype(d)}(undef, n) + # Get the correct type and memory by sampling from the untruncated distribution + samples = rand(rng, d0, n) n_collected = 0 max_batch = 0 - batch_buffer = Vector{partype(d)}() + batch_buffer 
= Vector{eltype(samples)}() while n_collected < n n_remaining = n - n_collected n_expected = n_remaining / tp @@ -271,14 +272,18 @@ function rand(rng::AbstractRNG, d::Truncated, n::Int) return samples elseif tp > sqrt(eps(typeof(float(tp)))) # Regime 2: Quantile-based sampling - samples = Vector{partype(d)}(undef, n) + # Sample one value first to determine the correct type + sample_type = typeof(quantile(d0, d.lcdf + rand(rng) * d.tp)) + samples = Vector{sample_type}(undef, n) for i in 1:n samples[i] = quantile(d0, d.lcdf + rand(rng) * d.tp) end return samples else # Regime 3: Log-space computation - samples = Vector{partype(d)}(undef, n) + # Sample one value first to determine the correct type + sample_type = typeof(invlogcdf(d0, logaddexp(d.loglcdf, d.logtp - randexp(rng)))) + samples = Vector{sample_type}(undef, n) for i in 1:n samples[i] = invlogcdf(d0, logaddexp(d.loglcdf, d.logtp - randexp(rng))) end