diff --git a/src/mixtures/mixturemodel.jl b/src/mixtures/mixturemodel.jl
index 7f3664290..38e9d59e5 100644
--- a/src/mixtures/mixturemodel.jl
+++ b/src/mixtures/mixturemodel.jl
@@ -477,6 +477,30 @@ rand(rng::AbstractRNG, s::MixtureSampler{Univariate}) =
 rand(rng::AbstractRNG, d::MixtureModel{Univariate}) =
     rand(rng, component(d, rand(rng, d.prior)))
 
+# Draw `n` variates from a univariate mixture in a single call.
+#
+# The number of variates taken from each component is drawn from a multinomial
+# distribution over the prior weights, each component is then sampled in bulk, and
+# the result is shuffled so that the output is not grouped by component.
+function rand(rng::AbstractRNG, d::MixtureModel{Univariate}, n::Int)
+    counts = rand(rng, Multinomial(n, probs(d.prior)))
+
+    # Sample the most frequent component first: its output vector becomes the result
+    # buffer, so the largest chunk never has to be copied.
+    max_count, max_count_idx = findmax(counts)
+    x = rand(rng, component(d, max_count_idx), max_count)
+
+    # Grow to the full size and fill the tail with the remaining components.
+    resize!(x, n)
+    offset = firstindex(x) + max_count
+    for (i, ni) in pairs(counts)
+        (i == max_count_idx || ni == 0) && continue
+        rand!(rng, component(d, i), view(x, offset:(offset + ni - 1)))
+        offset += ni
+    end
+    return shuffle!(rng, x)
+end
+
 # multivariate mixture sampler for a vector
 _rand!(rng::AbstractRNG, s::MixtureSampler{Multivariate}, x::AbstractVector{<:Real}) =
     @inbounds rand!(rng, s.csamplers[rand(rng, s.psampler)], x)
diff --git a/src/truncate.jl b/src/truncate.jl
index 48d62b015..f0a31ddc6 100644
--- a/src/truncate.jl
+++ b/src/truncate.jl
@@ -233,6 +233,77 @@ function rand(rng::AbstractRNG, d::Truncated)
     end
 end
 
+# Draw `n` variates from a truncated distribution in a single call.
+#
+# Uses the same three regimes as the scalar method: batched rejection sampling when
+# the truncated mass `tp` is large, inverse-cdf sampling for moderate `tp`, and
+# log-space inverse-cdf sampling for very small `tp`.
+function rand(rng::AbstractRNG, d::Truncated, n::Int)
+    d0 = d.untruncated
+    # Use `rng` even for the empty draw instead of the global RNG.
+    n == 0 && return rand(rng, d0, 0)
+
+    tp = d.tp
+    lower = d.lower
+    upper = d.upper
+
+    if tp > 0.25
+        # Regime 1: batched rejection sampling.
+        # The initial draw allocates the output with the right element type and also
+        # serves as the first batch of candidates: accepted values are compacted to
+        # the front (`n_collected <= i`, so no unread candidate is overwritten).
+        samples = rand(rng, d0, n)
+        n_collected = 0
+        for i in eachindex(samples)
+            s = samples[i]
+            if _in_closed_interval(s, lower, upper)
+                n_collected += 1
+                samples[n_collected] = s
+            end
+        end
+
+        batch_buffer = Vector{eltype(samples)}()
+        while n_collected < n
+            # Batch size: expected number of candidates needed plus a safety margin.
+            n_remaining = n - n_collected
+            n_expected = n_remaining / tp
+            δn_expected = sqrt(n_remaining * tp * (1 - tp))
+            n_batch = ceil(Int, n_expected + 3δn_expected)
+            # Fill exactly `n_batch` candidates so no surplus draws are wasted.
+            resize!(batch_buffer, n_batch)
+            rand!(rng, d0, batch_buffer)
+            for s in batch_buffer
+                if _in_closed_interval(s, lower, upper)
+                    n_collected += 1
+                    samples[n_collected] = s
+                    n_collected == n && break
+                end
+            end
+        end
+        return samples
+    elseif tp > sqrt(eps(typeof(float(tp))))
+        # Regime 2: inverse-cdf sampling.
+        # The first variate determines the element type and is kept, not discarded.
+        s1 = quantile(d0, d.lcdf + rand(rng) * tp)
+        samples = Vector{typeof(s1)}(undef, n)
+        samples[1] = s1
+        for i in 2:n
+            samples[i] = quantile(d0, d.lcdf + rand(rng) * tp)
+        end
+        return samples
+    else
+        # Regime 3: inverse-cdf sampling in log space.
+        # The first variate determines the element type and is kept, not discarded.
+        s1 = invlogcdf(d0, logaddexp(d.loglcdf, d.logtp - randexp(rng)))
+        samples = Vector{typeof(s1)}(undef, n)
+        samples[1] = s1
+        for i in 2:n
+            samples[i] = invlogcdf(d0, logaddexp(d.loglcdf, d.logtp - randexp(rng)))
+        end
+        return samples
+    end
+end
+
 ## show
 
 function show(io::IO, d::Truncated)