diff --git a/docs/src/extends.md b/docs/src/extends.md index 0c87b8bfbe..50dadd6f02 100644 --- a/docs/src/extends.md +++ b/docs/src/extends.md @@ -18,7 +18,7 @@ Unlike full-fledged distributions, a sampler, in general, only provides limited To implement a univariate sampler, one can define a subtype (say `Spl`) of `Sampleable{Univariate,S}` (where `S` can be `Discrete` or `Continuous`), and provide a `rand` method, as ```julia -function rand(rng::AbstractRNG, s::Spl) +function Distributions.rand(rng::AbstractRNG, s::Spl) # ... generate a single sample from s end ``` @@ -32,7 +32,7 @@ To implement a multivariate sampler, one can define a subtype of `Sampleable{Mul ```julia Base.length(s::Spl) = ... # return the length of each sample -function _rand!(rng::AbstractRNG, s::Spl, x::AbstractVector{T}) where T<:Real +function Distributions._rand!(rng::AbstractRNG, s::Spl, x::AbstractVector{T}) where T<:Real # ... generate a single vector sample to x end ``` @@ -80,7 +80,7 @@ Remember that each *column* of A is a sample. ### Matrix-variate Sampler -To implement a multivariate sampler, one can define a subtype of `Sampleable{Multivariate,S}`, and provide both `size` and `_rand!` methods, as +To implement a matrix-variate sampler, one can define a subtype of `Sampleable{Matrixvariate,S}`, and provide both `size` and `_rand!` methods, as ```julia Base.size(s::Spl) = ... # the size of each matrix sample @@ -104,7 +104,7 @@ sampler(d::Distribution) A univariate distribution type should be defined as a subtype of `DiscreteUnivarateDistribution` or `ContinuousUnivariateDistribution`. 
-The following methods need to be implemented for each univariate distribution type: +The following methods need to be implemented for each univariate distribution type (qualify each with `Distributions.`): - [`rand(::AbstractRNG, d::UnivariateDistribution)`](@ref) - [`sampler(d::Distribution)`](@ref) @@ -115,7 +115,7 @@ The following methods need to be implemented for each univariate distribution ty - [`maximum(d::UnivariateDistribution)`](@ref) - [`insupport(d::UnivariateDistribution, x::Real)`](@ref) -It is also recommended that one also implements the following statistics functions: +It is also recommended that one implements the following statistics functions (qualify each with `Distributions.`): - [`mean(d::UnivariateDistribution)`](@ref) - [`var(d::UnivariateDistribution)`](@ref) @@ -139,10 +139,10 @@ The following methods need to be implemented for each multivariate distribution - [`length(d::MultivariateDistribution)`](@ref) - [`sampler(d::Distribution)`](@ref) - [`eltype(d::Distribution)`](@ref) -- [`Distributions._rand!(::AbstractRNG, d::MultivariateDistribution, x::AbstractArray)`](@ref) -- [`Distributions._logpdf(d::MultivariateDistribution, x::AbstractArray)`](@ref) +- [`Distributions._rand!(::AbstractRNG, d::MultivariateDistribution, x::AbstractVector{<:Real})`](@ref) +- [`Distributions._logpdf(d::MultivariateDistribution, x::AbstractVector{<:Real})`](@ref) -Note that if there exist faster methods for batch evaluation, one should override `_logpdf!` and `_pdf!`. +Note that if there exist faster methods for batch evaluation, one may also override `Distributions._rand!(::AbstractRNG, d::MultivariateDistribution, x::AbstractMatrix{<:Real})` and [`Distributions._logpdf!`](@ref). Furthermore, the generic `loglikelihood` function repeatedly calls `_logpdf`. If there is a better way to compute the log-likelihood, one should override `loglikelihood`. 
@@ -161,6 +161,6 @@ A matrix-variate distribution type should be defined as a subtype of `DiscreteMa The following methods need to be implemented for each matrix-variate distribution type: - [`size(d::MatrixDistribution)`](@ref) -- [`Distributions._rand!(rng::AbstractRNG, d::MatrixDistribution, A::AbstractMatrix)`](@ref) +- [`Distributions._rand!(rng::AbstractRNG, d::MatrixDistribution, A::AbstractMatrix{<:Real})`](@ref) - [`sampler(d::MatrixDistribution)`](@ref) -- [`Distributions._logpdf(d::MatrixDistribution, x::AbstractArray)`](@ref) +- [`Distributions._logpdf(d::MatrixDistribution, x::AbstractMatrix{<:Real})`](@ref) diff --git a/src/multivariates.jl b/src/multivariates.jl index 56d91233cf..9548905249 100644 --- a/src/multivariates.jl +++ b/src/multivariates.jl @@ -107,6 +107,33 @@ function cor(d::MultivariateDistribution) return R end +""" + Distributions._rand!(::AbstractRNG, d::MultivariateDistribution, x::AbstractVector) + +Internal function for generating samples from `d` into `x`. When creating new multivariate distributions, +one should implement this method at least for `x::AbstractVector`. If there are faster +methods for creating samples in a batch, then consider implementing it also for `x::AbstractMatrix` +where each sample is one column of `x`. +""" +function _rand! end + +""" + Distributions._logpdf(d::MultivariateDistribution, x::AbstractVector{<:Real}) + +Internal function for computing the log-density of `d` at `x`. When creating new multivariate +distributions, one should implement this method at least for `x::AbstractVector{<:Real}`. If there are +faster methods for computing the log-density in a batch, then consider implementing +[`Distributions._logpdf!`](@ref). +""" +function _logpdf end + +""" + Distributions._logpdf!(r::AbstractArray{<:Real}, d::MultivariateDistribution, x::AbstractMatrix{<:Real}) + +An optional method to implement for multivariate distributions, computing `logpdf` for each column in `x`. +""" +function _logpdf! 
end + ##### Specific distributions ##### for fname in ["dirichlet.jl", diff --git a/test/extensions.jl b/test/extensions.jl new file mode 100644 index 0000000000..e85c268892 --- /dev/null +++ b/test/extensions.jl @@ -0,0 +1,192 @@ +# Test the extension interface described in https://juliastats.org/Distributions.jl/stable/extends/#Create-New-Samplers-and-Distributions + +module Extensions + +using Distributions +using Random + +### Samplers + +## Univariate Sampler +struct Dirac1Sampler{T} <: Sampleable{Univariate, Continuous} + x::T +end + +Distributions.rand(::AbstractRNG, s::Dirac1Sampler) = s.x + +## Multivariate Sampler +struct DiracNSampler{T} <: Sampleable{Multivariate, Continuous} + x::Vector{T} +end + +Base.length(s::DiracNSampler) = length(s.x) +Distributions._rand!(::AbstractRNG, s::DiracNSampler, x::AbstractVector) = x .= s.x + +## Matrix-variate sampler +struct DiracMVSampler{T} <: Sampleable{Matrixvariate, Continuous} + x::Matrix{T} +end + +Base.size(s::DiracMVSampler) = size(s.x) +Distributions._rand!(::AbstractRNG, s::DiracMVSampler, x::AbstractMatrix) = x .= s.x + + + +### Distributions + +## Univariate distribution +struct Dirac1{T} <: ContinuousUnivariateDistribution + x::T +end + +# required methods +Distributions.rand(::AbstractRNG, d::Dirac1) = d.x +Distributions.logpdf(d::Dirac1, x::Real) = x == d.x ? Inf : 0.0 +Distributions.cdf(d::Dirac1, x::Real) = x < d.x ? false : true +function Distributions.quantile(d::Dirac1, p::Real) + (p < zero(p) || p > oneunit(p)) && throw(DomainError()) + return iszero(p) ? 
typemin(d.x) : d.x +end +Distributions.minimum(d::Dirac1) = typemin(d.x) +Distributions.maximum(d::Dirac1) = typemax(d.x) +Distributions.insupport(d::Dirac1, x::Real) = minimum(d) < x < maximum(d) + +# recommended methods +Distributions.mean(d::Dirac1) = d.x +Distributions.var(d::Dirac1) = zero(d.x) +Distributions.mode(d::Dirac1) = d.x +# Distributions.modes(d::Dirac1) = [mode(d)] # test the fallback +Distributions.skewness(d::Dirac1) = zero(d.x) +Distributions.kurtosis(d::Dirac1, ::Bool) = zero(d.x) # conceived as the limit of a Gaussian for σ → 0 +Distributions.entropy(d::Dirac1) = zero(d.x) +Distributions.mgf(d::Dirac1, t::Real) = exp(t * d.x) +Distributions.cf(d::Dirac1, t::Real) = exp(t * d.x * im) + + +## Multivariate distribution +struct DiracN{T} <: ContinuousMultivariateDistribution + x::Vector{T} +end + +# required methods +Base.length(d::DiracN) = length(d.x) +Base.eltype(::DiracN{T}) where T = T +Distributions._rand!(::AbstractRNG, d::DiracN, x::AbstractVector) = x .= d.x +Distributions._rand!(::AbstractRNG, d::DiracN, x::AbstractMatrix) = x .= d.x +Distributions._logpdf(d::DiracN, x::AbstractVector{<:Real}) = x == d.x ? Inf : 0.0 +Distributions._logpdf(d::DiracN, x::AbstractMatrix{<:Real}) = map(y -> y == d.x ? Inf : 0.0, eachcol(x)) + +# recommended methods +Distributions.mean(d::DiracN) = d.x +Distributions.var(d::DiracN) = zero(d.x) +Distributions.entropy(::DiracN{T}) where T = zero(T) +Distributions.cov(d::DiracN) = zero(d.x) * zero(d.x)' + + +## Matrix-variate distribution +struct DiracMV{T} <: ContinuousMatrixDistribution + x::Matrix{T} +end + +# required methods +Base.size(d::DiracMV) = size(d.x) +Distributions._rand!(::AbstractRNG, d::DiracMV, x::AbstractMatrix) = x .= d.x +Distributions._logpdf(d::DiracMV, x::AbstractMatrix{<:Real}) = x == d.x ? 
Inf : 0.0 + + +end # module Extensions + +using Distributions +using Random +using Test + +@testset "Extensions" begin + ## Samplers + # Univariate + s = Extensions.Dirac1Sampler(1.0) + @test rand(s) == 1.0 + @test rand(s, 5) == ones(5) + @test rand!(s, zeros(5)) == ones(5) + # Multivariate + s = Extensions.DiracNSampler([1.0, 2.0, 3.0]) + @test rand(s) == [1.0, 2.0, 3.0] + @test rand(s, 5) == rand!(s, zeros(3, 5)) == repeat([1.0, 2.0, 3.0], 1, 5) + # Matrix-variate + s = Extensions.DiracMVSampler([1.0 2.0 3.0; 4.0 5.0 6.0]) + @test rand(s) == [1.0 2.0 3.0; 4.0 5.0 6.0] + @test rand(s, 5) == rand!(s, [zeros(2, 3) for i=1:5]) == [[1.0 2.0 3.0; 4.0 5.0 6.0] for i = 1:5] + + ## Distributions + # Univariate + d = Extensions.Dirac1(1.0) + @test rand(d) == 1.0 + @test rand(d, 5) == ones(5) + @test rand!(d, zeros(5)) == ones(5) + @test logpdf(d, 1.0) == Inf + @test logpdf(d, 2.0) == 0.0 + @test cdf(d, 0.0) == false + @test cdf(d, 1.0) == true + @test cdf(d, 2.0) == true + @test quantile(d, 0.0) == -Inf + @test quantile(d, 0.5) == 1.0 + @test quantile(d, 1.0) == 1.0 + @test minimum(d) == -Inf + @test maximum(d) == Inf + @test insupport(d, 0.0) == true + @test insupport(d, 1.0) == true + @test insupport(d, -Inf) == false + @test mean(d) == 1.0 + @test var(d) == 0.0 + @test mode(d) == 1.0 + @test skewness(d) == 0.0 + @test_broken kurtosis(d) == 0.0 + @test entropy(d) == 0.0 + @test mgf(d, 0.0) == 1.0 + @test mgf(d, 1.0) == exp(1.0) + @test cf(d, 0.0) == 1.0 + @test cf(d, 1.0) == exp(im) + # MixtureModel of Univariate + d = MixtureModel([Extensions.Dirac1(1.0), Extensions.Dirac1(2.0), Extensions.Dirac1(3.0)]) + @test rand(d) ∈ (1.0, 2.0, 3.0) + @test all(∈((1.0, 2.0, 3.0)), rand(d, 5)) + @test all(∈((1.0, 2.0, 3.0)), rand!(d, zeros(5))) + @test logpdf(d, 1.5) == 0.0 + @test logpdf(d, 2) == Inf + @test logpdf(d, [0.5, 2.0, 2.5]) == [0.0, Inf, 0.0] + @test mean(d) == 2 + + # Multivariate + d = Extensions.DiracN([1.0, 2.0, 3.0]) + @test length(d) == 3 + @test eltype(d) == 
Float64 + @test rand(d) == [1.0, 2.0, 3.0] + @test rand(d, 5) == rand!(d, zeros(3, 5)) == repeat([1.0, 2.0, 3.0], 1, 5) + @test logpdf(d, [1.0, 2, 3]) == Inf + @test logpdf(d, [1.0, 2, 4]) == 0.0 + @test logpdf(d, [1.0 1; 2 2; 3 4]) == [Inf, 0.0] + @test mean(d) == [1.0, 2.0, 3.0] + @test var(d) == [0.0, 0.0, 0.0] + @test entropy(d) == 0.0 + @test cov(d) == zeros(3, 3) + # Mixture model of multivariate + d = MixtureModel([Extensions.DiracN([1.0, 2.0, 3.0]), Extensions.DiracN([4.0, 5.0, 6.0])]) + @test rand(d) ∈ ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]) + @test all(∈(([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])), eachcol(rand(d, 5))) + @test all(∈(([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])), eachcol(rand!(d, zeros(3, 5)))) + @test logpdf(d, [1.0, 2, 3]) == Inf + @test logpdf(d, [4.0, 5, 6]) == Inf + @test logpdf(d, [1.0, 2, 4]) == 0.0 + + # Matrix-variate + d = Extensions.DiracMV([1.0 2.0 3.0; 4.0 5.0 6.0]) + @test size(d) == (2, 3) + @test rand(d) == [1.0 2.0 3.0; 4.0 5.0 6.0] + @test rand(d, 5) == rand!(d, [zeros(2, 3) for i=1:5]) == [[1.0 2.0 3.0; 4.0 5.0 6.0] for i = 1:5] + @test logpdf(d, [1.0 2.0 3.0; 4.0 5.0 6.0]) == Inf + @test logpdf(d, [1.0 2.0 3.0; 4.0 5.0 7.0]) == 0.0 + @test logpdf(d, [[1.0 2.0 3.0; 4.0 5.0 7.0], [1.0 2.0 3.0; 4.0 5.0 6.0]]) == [0.0, Inf] + # Mixtures of matrix-variate + d = MixtureModel([Extensions.DiracMV([1.0 2.0 3.0; 4.0 5.0 6.0]), Extensions.DiracMV([7.0 8.0 9.0; 10.0 11.0 12.0])]) + @test_broken rand(d) ∈ ([1.0 2.0 3.0; 4.0 5.0 6.0], [7.0 8.0 9.0; 10.0 11.0 12.0]) + @test_broken all(∈(([1.0 2.0 3.0; 4.0 5.0 6.0], [7.0 8.0 9.0; 10.0 11.0 12.0])), eachslice(rand(d, 5), dims=3)) +end diff --git a/test/runtests.jl b/test/runtests.jl index 583132c536..a913510656 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -96,6 +96,7 @@ const tests = [ "eachvariate", "univariate/continuous/triangular", "statsapi", + "extensions", ### missing files compared to /src: # "common",