Skip to content

Commit c430f39

Browse files
authored
Generalize Product (#1391)
* Move `src/multivariate/product.jl` * Generalize `Product` to `ProductDistribution` * Add implementations for more general product distributions * Unify and generalize `rand!`, `logpdf` and `pdf` * Revert unrelated changes and fix tests * Propagate `@inbounds` * Remove unneeded implementation * Fix typos * Fix some dispatches * More fixes * Support tuple of distributions and mix of discrete + continuous * Fix additional test errors * Fix method ambiguity * Fix `VonMisesFisherSampler` * Fix mixture sampler * Simplify multinomial sampler * Fix `loglikelihood` for univariate distributions * Add ReshapedDistribution * Fix typo * Revert some changes * Update product.jl * Remove duplicate `eachvariate`/`EachVariate` * Reintroduce `Product` * Improve type inference * Add explanations of `ValueSupport` * Fix typo * Remove another breaking change
1 parent 38dafb7 commit c430f39

File tree

7 files changed

+563
-43
lines changed

7 files changed

+563
-43
lines changed

docs/src/types.md

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,21 @@ The `VariateForm` sub-types defined in `Distributions.jl` are:
3333

3434
### ValueSupport
3535

36-
```@doc
36+
```@docs
3737
Distributions.ValueSupport
3838
```
3939

4040
The `ValueSupport` sub-types defined in `Distributions.jl` are:
4141

42-
**Type** | **Element type** | **Descriptions**
43-
--- | --- | ---
44-
`Discrete` | `Int` | Samples take discrete values
45-
`Continuous` | `Float64` | Samples take continuous real values
42+
```@docs
43+
Distributions.Discrete
44+
Distributions.Continuous
45+
```
46+
47+
**Type** | **Default element type** | **Description** | **Examples**
48+
--- | --- | --- | ---
49+
`Discrete` | `Int` | Samples take countably many values | $\{0,1,2,3\}$, $\mathbb{N}$
50+
`Continuous` | `Float64` | Samples take uncountably many values | $[0, 1]$, $\mathbb{R}$
4651

4752
Multiple samples are often organized into an array, depending on the variate form.
4853

src/Distributions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ export
145145
Pareto,
146146
PGeneralizedGaussian,
147147
SkewedExponentialPower,
148-
Product,
148+
Product, # deprecated
149149
Poisson,
150150
PoissonBinomial,
151151
QQPair,
@@ -293,6 +293,7 @@ include("cholesky/lkjcholesky.jl")
293293
include("samplers.jl")
294294

295295
# others
296+
include("product.jl")
296297
include("reshaped.jl")
297298
include("truncate.jl")
298299
include("censored.jl")

src/common.jl

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,42 @@ const Matrixvariate = ArrayLikeVariate{2}
2323
abstract type CholeskyVariate <: VariateForm end
2424

2525
"""
26-
`S <: ValueSupport` specifies the support of sample elements,
27-
either discrete or continuous.
26+
ValueSupport
27+
28+
Abstract type that specifies the support of elements of samples.
29+
30+
It is either [`Discrete`](@ref) or [`Continuous`](@ref).
2831
"""
2932
abstract type ValueSupport end
33+
34+
"""
35+
Discrete <: ValueSupport
36+
37+
This type represents the support of a discrete random variable.
38+
39+
It is countable. For instance, it can be a finite set or a countably infinite set such as
40+
the natural numbers.
41+
42+
See also: [`Continuous`](@ref), [`ValueSupport`](@ref)
43+
"""
3044
struct Discrete <: ValueSupport end
45+
46+
"""
47+
Continuous <: ValueSupport
48+
49+
This types represents the support of a continuous random variable.
50+
51+
It is uncountably infinite. For instance, it can be an interval on the real line.
52+
53+
See also: [`Discrete`](@ref), [`ValueSupport`](@ref)
54+
"""
3155
struct Continuous <: ValueSupport end
3256

57+
# promotions (e.g., in product distribution):
58+
# combination of discrete support (countable) and continuous support (uncountable) yields
59+
# continuous support (uncountable)
60+
Base.promote_rule(::Type{Continuous}, ::Type{Discrete}) = Continuous
61+
3362
## Sampleable
3463

3564
"""
@@ -42,7 +71,6 @@ Any `Sampleable` implements the `Base.rand` method.
4271
"""
4372
abstract type Sampleable{F<:VariateForm,S<:ValueSupport} end
4473

45-
4674
variate_form(::Type{<:Sampleable{VF}}) where {VF} = VF
4775
value_support(::Type{<:Sampleable{<:VariateForm,VS}}) where {VS} = VS
4876

@@ -142,10 +170,6 @@ const ContinuousMultivariateDistribution = Distribution{Multivariate, Continuou
142170
const DiscreteMatrixDistribution = Distribution{Matrixvariate, Discrete}
143171
const ContinuousMatrixDistribution = Distribution{Matrixvariate, Continuous}
144172

145-
variate_form(::Type{<:Distribution{VF}}) where {VF} = VF
146-
147-
value_support(::Type{<:Distribution{VF,VS}}) where {VF,VS} = VS
148-
149173
# allow broadcasting over distribution objects
150174
# to be decided: how to handle multivariate/matrixvariate distributions?
151175
Broadcast.broadcastable(d::UnivariateDistribution) = Ref(d)

src/multivariate/product.jl

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import Statistics: mean, var, cov
1+
# Deprecated product distribution
2+
# TODO: Remove in next breaking release
23

34
"""
45
Product <: MultivariateDistribution
@@ -20,6 +21,10 @@ struct Product{
2021
V<:AbstractVector{T} where
2122
T<:UnivariateDistribution{S} where
2223
S<:ValueSupport
24+
Base.depwarn(
25+
"`Product(v)` is deprecated, please use `product_distribution(v)`",
26+
:Product,
27+
)
2328
return new{S, T, V}(v)
2429
end
2530
end
@@ -43,26 +48,9 @@ insupport(d::Product, x::AbstractVector) = all(insupport.(d.v, x))
4348
minimum(d::Product) = map(minimum, d.v)
4449
maximum(d::Product) = map(maximum, d.v)
4550

46-
"""
47-
product_distribution(dists::AbstractVector{<:UnivariateDistribution})
48-
49-
Creates a multivariate product distribution `P` from a vector of univariate distributions.
50-
Fallback is the `Product constructor`, but specialized methods can be defined
51-
for distributions with a special multivariate product.
52-
"""
53-
function product_distribution(dists::AbstractVector{<:UnivariateDistribution})
54-
return Product(dists)
55-
end
56-
57-
"""
58-
product_distribution(dists::AbstractVector{<:Normal})
59-
60-
Computes the multivariate Normal distribution obtained by stacking the univariate
61-
normal distributions. The result is a multivariate Gaussian with a diagonal
62-
covariance matrix.
63-
"""
64-
function product_distribution(dists::AbstractVector{<:Normal})
65-
µ = mean.(dists)
66-
σ2 = var.(dists)
67-
return MvNormal(µ, Diagonal(σ2))
68-
end
51+
# TODO: remove deprecation when `Product` is removed
52+
# it will return a `ProductDistribution` then which is already the default for
53+
# higher-dimensional arrays and distributions
54+
Base.@deprecate product_distribution(
55+
dists::AbstractVector{<:UnivariateDistribution}
56+
) Product(dists)

src/multivariates.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ for fname in ["dirichlet.jl",
116116
"mvnormalcanon.jl",
117117
"mvlognormal.jl",
118118
"mvtdist.jl",
119-
"product.jl",
119+
"product.jl", # deprecated
120120
"vonmisesfisher.jl"]
121121
include(joinpath("multivariate", fname))
122122
end

0 commit comments

Comments
 (0)