Tweak docs, add tests for extension interface #1825

Open · wants to merge 1 commit into master

20 changes: 10 additions & 10 deletions docs/src/extends.md
@@ -18,7 +18,7 @@ Unlike full-fledged distributions, a sampler, in general, only provides limited
To implement a univariate sampler, one can define a subtype (say `Spl`) of `Sampleable{Univariate,S}` (where `S` can be `Discrete` or `Continuous`), and provide a `rand` method, as

```julia
-function rand(rng::AbstractRNG, s::Spl)
+function Distributions.rand(rng::AbstractRNG, s::Spl)
# ... generate a single sample from s
end
```
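
As a concrete illustration of this pattern, a minimal self-contained sampler might look like the sketch below; `ShiftedExpSampler` and its field are hypothetical names used only for this example.

```julia
using Distributions, Random

# A sampler that draws an exponential variate and shifts it by a constant.
struct ShiftedExpSampler <: Sampleable{Univariate,Continuous}
    shift::Float64
end

# The single required method: draw one sample using the supplied RNG.
Distributions.rand(rng::AbstractRNG, s::ShiftedExpSampler) = s.shift + randexp(rng)

s = ShiftedExpSampler(2.0)
rand(s)      # one draw
rand(s, 10)  # vector of 10 draws, via the generic vectorized fallback
```
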
@@ -32,7 +32,7 @@ To implement a multivariate sampler, one can define a subtype of `Sampleable{Mul
```julia
Base.length(s::Spl) = ... # return the length of each sample

-function _rand!(rng::AbstractRNG, s::Spl, x::AbstractVector{T}) where T<:Real
+function Distributions._rand!(rng::AbstractRNG, s::Spl, x::AbstractVector{T}) where T<:Real
# ... generate a single vector sample to x
end
```
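
Likewise, a minimal multivariate sampler could be sketched as follows; `OffsetNormalSampler` is a hypothetical example type that draws unit-variance normal entries around a fixed offset vector.

```julia
using Distributions, Random

# Draws a vector of iid standard normals shifted by a fixed offset.
struct OffsetNormalSampler <: Sampleable{Multivariate,Continuous}
    offset::Vector{Float64}
end

Base.length(s::OffsetNormalSampler) = length(s.offset)

function Distributions._rand!(rng::AbstractRNG, s::OffsetNormalSampler, x::AbstractVector{<:Real})
    randn!(rng, x)   # fill x with standard normal draws
    x .+= s.offset   # shift in place
    return x
end

s = OffsetNormalSampler([1.0, -1.0])
rand(s)                # a Vector of length 2
rand(s, 5)             # a 2×5 Matrix, one sample per column
rand!(s, zeros(2, 5))  # fill an existing matrix column by column
```
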
@@ -80,7 +80,7 @@ Remember that each *column* of A is a sample.

### Matrix-variate Sampler

-To implement a multivariate sampler, one can define a subtype of `Sampleable{Multivariate,S}`, and provide both `size` and `_rand!` methods, as
+To implement a matrix-variate sampler, one can define a subtype of `Sampleable{Matrixvariate,S}`, and provide both `size` and `_rand!` methods, as

```julia
Base.size(s::Spl) = ... # the size of each matrix sample
@@ -104,7 +104,7 @@ sampler(d::Distribution)

A univariate distribution type should be defined as a subtype of `DiscreteUnivariateDistribution` or `ContinuousUnivariateDistribution`.

-The following methods need to be implemented for each univariate distribution type:
+The following methods need to be implemented for each univariate distribution type (qualify each with `Distributions.`):

- [`rand(::AbstractRNG, d::UnivariateDistribution)`](@ref)
- [`sampler(d::Distribution)`](@ref)
@@ -115,7 +115,7 @@ The following methods need to be implemented for each univariate distribution ty
- [`maximum(d::UnivariateDistribution)`](@ref)
- [`insupport(d::UnivariateDistribution, x::Real)`](@ref)

-It is also recommended that one also implements the following statistics functions:
+It is also recommended that one implements the following statistics functions (qualify each with `Distributions.`):

- [`mean(d::UnivariateDistribution)`](@ref)
- [`var(d::UnivariateDistribution)`](@ref)
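
Putting the required methods and a couple of the recommended statistics together, a from-scratch uniform distribution might be sketched as below; `MyUniform` is a hypothetical name, and in practice the package's own `Uniform` covers this case.

```julia
using Distributions, Random

struct MyUniform <: ContinuousUnivariateDistribution
    a::Float64
    b::Float64
end

# required methods
Distributions.rand(rng::AbstractRNG, d::MyUniform) = d.a + (d.b - d.a) * rand(rng)
Distributions.sampler(d::MyUniform) = d   # the distribution acts as its own sampler here
Distributions.logpdf(d::MyUniform, x::Real) = insupport(d, x) ? -log(d.b - d.a) : -Inf
Distributions.cdf(d::MyUniform, x::Real) = clamp((x - d.a) / (d.b - d.a), 0, 1)
Distributions.quantile(d::MyUniform, p::Real) = d.a + p * (d.b - d.a)
Distributions.minimum(d::MyUniform) = d.a
Distributions.maximum(d::MyUniform) = d.b
Distributions.insupport(d::MyUniform, x::Real) = d.a <= x <= d.b

# a couple of the recommended statistics
Distributions.mean(d::MyUniform) = (d.a + d.b) / 2
Distributions.var(d::MyUniform) = (d.b - d.a)^2 / 12
```

With these definitions in place, the generic API (`rand(d, n)`, `pdf`, `ccdf`, and so on) works through the package's fallbacks, much as the new `test/extensions.jl` below exercises for its `Dirac1` type.
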
@@ -139,10 +139,10 @@ The following methods need to be implemented for each multivariate distribution
- [`length(d::MultivariateDistribution)`](@ref)
- [`sampler(d::Distribution)`](@ref)
- [`eltype(d::Distribution)`](@ref)
-- [`Distributions._rand!(::AbstractRNG, d::MultivariateDistribution, x::AbstractArray)`](@ref)
-- [`Distributions._logpdf(d::MultivariateDistribution, x::AbstractArray)`](@ref)
+- [`Distributions._rand!(::AbstractRNG, d::MultivariateDistribution, x::AbstractVector{<:Real})`](@ref)
+- [`Distributions._logpdf(d::MultivariateDistribution, x::AbstractVector{<:Real})`](@ref)

-Note that if there exist faster methods for batch evaluation, one should override `_logpdf!` and `_pdf!`.
+Note that if there exist faster methods for batch evaluation, one may also override `Distributions._rand!(::AbstractRNG, d::MultivariateDistribution, x::AbstractMatrix{<:Real})` and [`Distributions._logpdf!`](@ref).

Furthermore, the generic `loglikelihood` function repeatedly calls `_logpdf`. If there is
a better way to compute the log-likelihood, one should override `loglikelihood`.
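
For instance, an isotropic unit-variance Gaussian could be written along these lines; `SimpleIsoNormal` is a hypothetical name for this sketch, and in practice `MvNormal` already provides this distribution.

```julia
using Distributions, Random

# Isotropic Gaussian with identity covariance, centred at μ.
struct SimpleIsoNormal <: ContinuousMultivariateDistribution
    μ::Vector{Float64}
end

Base.length(d::SimpleIsoNormal) = length(d.μ)
Base.eltype(::SimpleIsoNormal) = Float64

function Distributions._rand!(rng::AbstractRNG, d::SimpleIsoNormal, x::AbstractVector{<:Real})
    randn!(rng, x)
    x .+= d.μ
    return x
end

Distributions._logpdf(d::SimpleIsoNormal, x::AbstractVector{<:Real}) =
    -0.5 * (length(d) * log(2π) + sum(abs2, x .- d.μ))

d = SimpleIsoNormal([0.0, 1.0])
rand(d, 4)                    # 2×4 matrix, one sample per column
logpdf(d, [0.0, 1.0])         # log-density at the mean
loglikelihood(d, rand(d, 4))  # sums the per-column log-densities
```

If a faster batch evaluation exists, the matrix method of `_rand!` and the `_logpdf!` method mentioned in the note above can be added on top; their contracts are the docstrings added to `src/multivariates.jl` below.
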
@@ -161,6 +161,6 @@ A matrix-variate distribution type should be defined as a subtype of `DiscreteMa
The following methods need to be implemented for each matrix-variate distribution type:

- [`size(d::MatrixDistribution)`](@ref)
-- [`Distributions._rand!(rng::AbstractRNG, d::MatrixDistribution, A::AbstractMatrix)`](@ref)
+- [`Distributions._rand!(rng::AbstractRNG, d::MatrixDistribution, A::AbstractMatrix{<:Real})`](@ref)
- [`sampler(d::MatrixDistribution)`](@ref)
-- [`Distributions._logpdf(d::MatrixDistribution, x::AbstractArray)`](@ref)
+- [`Distributions._logpdf(d::MatrixDistribution, x::AbstractMatrix{<:Real})`](@ref)
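
A matrix-variate analogue of the sketches above, using the hypothetical `StdNormalMatrix` type whose entries are iid standard normal:

```julia
using Distributions, Random

# A distribution over n×p matrices with iid standard normal entries.
struct StdNormalMatrix <: ContinuousMatrixDistribution
    n::Int
    p::Int
end

Base.size(d::StdNormalMatrix) = (d.n, d.p)

Distributions._rand!(rng::AbstractRNG, d::StdNormalMatrix, A::AbstractMatrix{<:Real}) = randn!(rng, A)

Distributions._logpdf(d::StdNormalMatrix, X::AbstractMatrix{<:Real}) =
    -0.5 * (length(X) * log(2π) + sum(abs2, X))

d = StdNormalMatrix(2, 3)
rand(d)             # a 2×3 matrix
logpdf(d, rand(d))  # log-density of a draw
```
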
27 changes: 27 additions & 0 deletions src/multivariates.jl
@@ -107,6 +107,33 @@ function cor(d::MultivariateDistribution)
return R
end

"""
Distributions._rand!(::AbstractRNG, d::MultivariateDistribution, x::AbstractVector)

Internal function for generating samples from `d` into `x`. When creating new multivariate distributions,
one should implement this method at least for `x::AbstractVector`. If there are faster
methods for creating samples in a batch, then consider implementing it also for `x::AbstractMatrix`
where each sample is one column of `x`.
"""
function _rand! end

"""
Distributions._logpdf(d::MultivariateDistribution, x::AbstractVector{<:Real})

Internal function for computing the log-density of `d` at `x`. When creating new multivariate
distributions, one should implement this method at least for `x::AbstractVector{<:Real}`. If there are
faster methods for computing the log-density in a batch, then consider implementing
[`Distributions._logpdf!`](@ref).
"""
function _logpdf end

"""
Distributions._logpdf!(r::AbstractArray{<:Real}, d::MultivariateDistribution, x::AbstractMatrix{<:Real})

An optional method to implement for multivariate distributions, computing `logpdf` for each column in `x`.
"""
function _logpdf! end

##### Specific distributions #####

for fname in ["dirichlet.jl",
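
To make the role of the `_logpdf!` hook documented above concrete, a distribution that benefits from batch evaluation might override it roughly as in the sketch below; `BatchIsoNormal` is a hypothetical example type, and the generic entry points that dispatch to `_logpdf!` are those described in the docstrings.

```julia
using Distributions, Random

struct BatchIsoNormal <: ContinuousMultivariateDistribution
    μ::Vector{Float64}
end

Base.length(d::BatchIsoNormal) = length(d.μ)

function Distributions._rand!(rng::AbstractRNG, d::BatchIsoNormal, x::AbstractVector{<:Real})
    randn!(rng, x)
    x .+= d.μ
    return x
end

# Per-sample log-density (the mandatory method).
Distributions._logpdf(d::BatchIsoNormal, x::AbstractVector{<:Real}) =
    -0.5 * (length(d) * log(2π) + sum(abs2, x .- d.μ))

# Optional batch method: one vectorized pass over all columns of X
# instead of the generic column-by-column fallback.
function Distributions._logpdf!(r::AbstractArray{<:Real}, d::BatchIsoNormal,
                                X::AbstractMatrix{<:Real})
    r .= -0.5 .* (length(d) * log(2π) .+ vec(sum(abs2, X .- d.μ; dims=1)))
    return r
end
```

Batch calls that route through `_logpdf!` then evaluate all columns of `X` in one pass instead of looping over them.
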
192 changes: 192 additions & 0 deletions test/extensions.jl
@@ -0,0 +1,192 @@
# Test the extension interface described in https://juliastats.org/Distributions.jl/stable/extends/#Create-New-Samplers-and-Distributions

module Extensions

using Distributions
using Random

### Samplers

## Univariate Sampler
struct Dirac1Sampler{T} <: Sampleable{Univariate, Continuous}
x::T
end

Distributions.rand(::AbstractRNG, s::Dirac1Sampler) = s.x

## Multivariate Sampler
struct DiracNSampler{T} <: Sampleable{Multivariate, Continuous}
x::Vector{T}
end

Base.length(s::DiracNSampler) = length(s.x)
Distributions._rand!(::AbstractRNG, s::DiracNSampler, x::AbstractVector) = x .= s.x

## Matrix-variate sampler
struct DiracMVSampler{T} <: Sampleable{Matrixvariate, Continuous}
x::Matrix{T}
end

Base.size(s::DiracMVSampler) = size(s.x)
Distributions._rand!(::AbstractRNG, s::DiracMVSampler, x::AbstractMatrix) = x .= s.x



### Distributions

## Univariate distribution
struct Dirac1{T} <: ContinuousUnivariateDistribution
x::T
end

# required methods
Distributions.rand(::AbstractRNG, d::Dirac1) = d.x
Distributions.logpdf(d::Dirac1, x::Real) = x == d.x ? Inf : 0.0
Distributions.cdf(d::Dirac1, x::Real) = x < d.x ? false : true
function Distributions.quantile(d::Dirac1, p::Real)
(p < zero(p) || p > oneunit(p)) && throw(DomainError(p, "p must be in [0, 1]"))
return iszero(p) ? typemin(d.x) : d.x
end
Distributions.minimum(d::Dirac1) = typemin(d.x)
Distributions.maximum(d::Dirac1) = typemax(d.x)
Distributions.insupport(d::Dirac1, x::Real) = minimum(d) < x < maximum(d)

# recommended methods
Distributions.mean(d::Dirac1) = d.x
Distributions.var(d::Dirac1) = zero(d.x)
Distributions.mode(d::Dirac1) = d.x
# Distributions.modes(d::Dirac1) = [mode(d)] # test the fallback
Distributions.skewness(d::Dirac1) = zero(d.x)
Distributions.kurtosis(d::Dirac1, ::Bool) = zero(d.x) # conceived as the limit of a Gaussian for σ → 0
Distributions.entropy(d::Dirac1) = zero(d.x)
Distributions.mgf(d::Dirac1, t::Real) = exp(t * d.x)
Distributions.cf(d::Dirac1, t::Real) = exp(t * d.x * im)


## Multivariate distribution
struct DiracN{T} <: ContinuousMultivariateDistribution
x::Vector{T}
end

# required methods
Base.length(d::DiracN) = length(d.x)
Base.eltype(::DiracN{T}) where T = T
Distributions._rand!(::AbstractRNG, d::DiracN, x::AbstractVector) = x .= d.x
Distributions._rand!(::AbstractRNG, d::DiracN, x::AbstractMatrix) = x .= d.x
Distributions._logpdf(d::DiracN, x::AbstractVector{<:Real}) = x == d.x ? Inf : 0.0
Distributions._logpdf(d::DiracN, x::AbstractMatrix{<:Real}) = map(y -> y == d.x ? Inf : 0.0, eachcol(x))

# recommended methods
Distributions.mean(d::DiracN) = d.x
Distributions.var(d::DiracN) = zero(d.x)
Distributions.entropy(::DiracN{T}) where T = zero(T)
Distributions.cov(d::DiracN) = zero(d.x) * zero(d.x)'


## Matrix-variate distribution
struct DiracMV{T} <: ContinuousMatrixDistribution
x::Matrix{T}
end

# required methods
Base.size(d::DiracMV) = size(d.x)
Distributions._rand!(::AbstractRNG, d::DiracMV, x::AbstractMatrix) = x .= d.x
Distributions._logpdf(d::DiracMV, x::AbstractMatrix{<:Real}) = x == d.x ? Inf : 0.0


end # module Extensions

using Distributions
using Random
using Test

@testset "Extensions" begin
## Samplers
# Univariate
s = Extensions.Dirac1Sampler(1.0)
@test rand(s) == 1.0
@test rand(s, 5) == ones(5)
@test rand!(s, zeros(5)) == ones(5)
# Multivariate
s = Extensions.DiracNSampler([1.0, 2.0, 3.0])
@test rand(s) == [1.0, 2.0, 3.0]
@test rand(s, 5) == rand!(s, zeros(3, 5)) == repeat([1.0, 2.0, 3.0], 1, 5)
# Matrix-variate
s = Extensions.DiracMVSampler([1.0 2.0 3.0; 4.0 5.0 6.0])
@test rand(s) == [1.0 2.0 3.0; 4.0 5.0 6.0]
@test rand(s, 5) == rand!(s, [zeros(2, 3) for i=1:5]) == [[1.0 2.0 3.0; 4.0 5.0 6.0] for i = 1:5]

## Distributions
# Univariate
d = Extensions.Dirac1(1.0)
@test rand(d) == 1.0
@test rand(d, 5) == ones(5)
@test rand!(d, zeros(5)) == ones(5)
@test logpdf(d, 1.0) == Inf
@test logpdf(d, 2.0) == 0.0
@test cdf(d, 0.0) == false
@test cdf(d, 1.0) == true
@test cdf(d, 2.0) == true
@test quantile(d, 0.0) == -Inf
@test quantile(d, 0.5) == 1.0
@test quantile(d, 1.0) == 1.0
@test minimum(d) == -Inf
@test maximum(d) == Inf
@test insupport(d, 0.0) == true
@test insupport(d, 1.0) == true
@test insupport(d, -Inf) == false
@test mean(d) == 1.0
@test var(d) == 0.0
@test mode(d) == 1.0
@test skewness(d) == 0.0
@test_broken kurtosis(d) == 0.0
@test entropy(d) == 0.0
@test mgf(d, 0.0) == 1.0
@test mgf(d, 1.0) == exp(1.0)
@test cf(d, 0.0) == 1.0
@test cf(d, 1.0) == exp(im)
# MixtureModel of Univariate
d = MixtureModel([Extensions.Dirac1(1.0), Extensions.Dirac1(2.0), Extensions.Dirac1(3.0)])
@test rand(d) ∈ (1.0, 2.0, 3.0)
@test all(∈((1.0, 2.0, 3.0)), rand(d, 5))
@test all(∈((1.0, 2.0, 3.0)), rand!(d, zeros(5)))
@test logpdf(d, 1.5) == 0.0
@test logpdf(d, 2) == Inf
@test logpdf(d, [0.5, 2.0, 2.5]) == [0.0, Inf, 0.0]
@test mean(d) == 2

# Multivariate
d = Extensions.DiracN([1.0, 2.0, 3.0])
@test length(d) == 3
@test eltype(d) == Float64
@test rand(d) == [1.0, 2.0, 3.0]
@test rand(d, 5) == rand!(d, zeros(3, 5)) == repeat([1.0, 2.0, 3.0], 1, 5)
@test logpdf(d, [1.0, 2, 3]) == Inf
@test logpdf(d, [1.0, 2, 4]) == 0.0
@test logpdf(d, [1.0 1; 2 2; 3 4]) == [Inf, 0.0]
@test mean(d) == [1.0, 2.0, 3.0]
@test var(d) == [0.0, 0.0, 0.0]
@test entropy(d) == 0.0
@test cov(d) == zeros(3, 3)
# Mixture model of multivariate
d = MixtureModel([Extensions.DiracN([1.0, 2.0, 3.0]), Extensions.DiracN([4.0, 5.0, 6.0])])
@test rand(d) ∈ ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])
@test all(∈(([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])), eachcol(rand(d, 5)))
@test all(∈(([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])), eachcol(rand!(d, zeros(3, 5))))
@test logpdf(d, [1.0, 2, 3]) == Inf
@test logpdf(d, [4.0, 5, 6]) == Inf
@test logpdf(d, [1.0, 2, 4]) == 0.0

# Matrix-variate
d = Extensions.DiracMV([1.0 2.0 3.0; 4.0 5.0 6.0])
@test size(d) == (2, 3)
@test rand(d) == [1.0 2.0 3.0; 4.0 5.0 6.0]
@test rand(d, 5) == rand!(d, [zeros(2, 3) for i=1:5]) == [[1.0 2.0 3.0; 4.0 5.0 6.0] for i = 1:5]
@test logpdf(d, [1.0 2.0 3.0; 4.0 5.0 6.0]) == Inf
@test logpdf(d, [1.0 2.0 3.0; 4.0 5.0 7.0]) == 0.0
@test logpdf(d, [[1.0 2.0 3.0; 4.0 5.0 7.0], [1.0 2.0 3.0; 4.0 5.0 6.0]]) == [0.0, Inf]
# Mixtures of matrix-variate
d = MixtureModel([Extensions.DiracMV([1.0 2.0 3.0; 4.0 5.0 6.0]), Extensions.DiracMV([7.0 8.0 9.0; 10.0 11.0 12.0])])
@test_broken rand(d) ∈ ([1.0 2.0 3.0; 4.0 5.0 6.0], [7.0 8.0 9.0; 10.0 11.0 12.0])
@test_broken all(∈(([1.0 2.0 3.0; 4.0 5.0 6.0], [7.0 8.0 9.0; 10.0 11.0 12.0])), eachslice(rand(d, 5), dims=3))
end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -96,6 +96,7 @@ const tests = [
"eachvariate",
"univariate/continuous/triangular",
"statsapi",
"extensions",

### missing files compared to /src:
# "common",