Skip to content

Dispatch for drawing multiples #1985

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/mixtures/mixturemodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,22 @@ rand(rng::AbstractRNG, s::MixtureSampler{Univariate}) =
rand(rng::AbstractRNG, d::MixtureModel{Univariate}) =
rand(rng, component(d, rand(rng, d.prior)))

function rand(rng::AbstractRNG, d::MixtureModel{Univariate}, n::Int)
counts = rand(rng, Multinomial(n, probs(d.prior)))
x = Vector{eltype(d)}(undef, n)
offset = 0
for i in eachindex(counts)
ni = counts[i]
if ni > 0
c = component(d, i)
v = view(x, (offset+1):(offset+ni))
v .= rand(rng, c, ni)
offset += ni
end
end
return shuffle!(rng, x)
end

# multivariate mixture sampler for a vector
_rand!(rng::AbstractRNG, s::MixtureSampler{Multivariate}, x::AbstractVector{<:Real}) =
@inbounds rand!(rng, s.csamplers[rand(rng, s.psampler)], x)
Expand Down
53 changes: 53 additions & 0 deletions src/truncate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,17 @@ end

## random number generation


"""
rand(rng::AbstractRNG, d::Truncated)

Generate a single random sample from a truncated distribution.

The sampling strategy depends on the probability mass of the truncated region (`tp`):
- If `tp > 0.25`, rejection sampling is used. This is efficient when the truncated region covers a large portion of the original distribution.
- If `sqrt(eps) < tp <= 0.25`, inverse transform sampling is used. This is more efficient for smaller truncated regions.
- If `tp` is very small (`<= sqrt(eps)`), a numerically stable version of inverse transform sampling is used which performs calculations in log-space to maintain precision.
"""
function rand(rng::AbstractRNG, d::Truncated)
d0 = d.untruncated
tp = d.tp
Expand All @@ -233,6 +244,48 @@ function rand(rng::AbstractRNG, d::Truncated)
end
end


"""
rand(rng::AbstractRNG, d::Truncated, n::Int)

Generate `n` random samples from a truncated distribution.

The implementation samples the untruncated distribution, `d0` with `rand(rng, d0, n)` in batches and only keeps the samples that fall within the truncated range. The size of the batches is adaptively estimated to reduce the number of iterations.

See [rand(rng::AbstractRNG, d::Truncated)](@ref) that handles the case of small mass of the truncated region.

!!! warning
This method can be inefficient if the probability mass of the truncated region is very small.
"""
function rand(rng::AbstractRNG, d::Truncated, n::Int)
n == 0 && return eltype(d)[]
#
d0 = d.untruncated
tp = d.tp
lower = d.lower
upper = d.upper
# Preallocate samples array
samples = Vector{eltype(d)}(undef, n)
n_collected = 0
while n_collected < n
n_remaining = n - n_collected
# Estimate number of samples to draw from the untruncated distribution.
# We draw more to reduce the chance of needing more rounds.
n_expected = n_remaining / tp
δn_expected = sqrt(n_remaining * tp * (1 - tp)) # standard deviation of the expected number
n_batch = ceil(Int, n_expected + 3δn_expected)
samples_d0 = rand(rng, d0, n_batch)
for s in samples_d0
if _in_closed_interval(s, lower, upper)
n_collected += 1
samples[n_collected] = s
n_collected == n && break
end
end
end
return samples
end

## show

function show(io::IO, d::Truncated)
Expand Down
Loading