Skip to content

Commit da3a230

Browse files
Implement std(::MultivariateDistribution) (#1352)
* Implement std(::MultivariateDistribution) * Update src/multivariates.jl * Add `std` to docs * Add `std(::Dirichlet)` * Safer default implementation * More tests --------- Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
1 parent 8bb4181 commit da3a230

12 files changed

+33
-0
lines changed

docs/src/multivariate.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ size(::MultivariateDistribution)
2121
eltype(::Type{MultivariateDistribution})
2222
mean(::MultivariateDistribution)
2323
var(::MultivariateDistribution)
24+
std(::MultivariateDistribution)
2425
cov(::MultivariateDistribution)
2526
cor(::MultivariateDistribution)
2627
entropy(::MultivariateDistribution)

ext/DistributionsTestExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ function Distributions.TestUtils.test_mvnormal(
3333
@test length(μ) == d
3434
@test size(Σ) == (d, d)
3535
@test var(g) diag(Σ)
36+
@test std(g) sqrt.(diag(Σ))
3637
@test entropy(g) 0.5 * logdet(2π ** Σ)
3738
ldcov = logdetcov(g)
3839
@test ldcov logdet(Σ)

src/multivariate/product.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ function _logpdf(d::Product, x::AbstractVector{<:Real})
5151
end
5252

5353
mean(d::Product) = mean.(d.v)
54+
std(d::Product) = std.(d.v)
5455
var(d::Product) = var.(d.v)
5556
cov(d::Product) = Diagonal(var(d))
5657
entropy(d::Product) = sum(entropy, d.v)

src/multivariates.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ Compute the vector of element-wise variances for distribution `d`.
6262
"""
6363
var(d::MultivariateDistribution)
6464

65+
"""
66+
std(d::MultivariateDistribution)
67+
68+
Compute the vector of element-wise standard deviations for distribution `d`.
69+
"""
70+
std(d::MultivariateDistribution) = sqrt!!(var(d))
71+
6572
"""
6673
entropy(d::MultivariateDistribution)
6774

src/product.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ cov(d::ProductDistribution) = Diagonal(vec(var(d)))
8383
## For product distributions of univariate distributions
8484
mean(d::ArrayOfUnivariateDistribution) = map(mean, d.dists)
8585
mean(d::VectorOfUnivariateDistribution{<:Tuple}) = collect(map(mean, d.dists))
86+
std(d::ArrayOfUnivariateDistribution) = map(std, d.dists)
87+
std(d::VectorOfUnivariateDistribution{<:Tuple}) = collect(map(std, d.dists))
8688
var(d::ArrayOfUnivariateDistribution) = map(var, d.dists)
8789
var(d::VectorOfUnivariateDistribution{<:Tuple}) = collect(map(var, d.dists))
8890

src/utils.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,14 @@ isunitvec(v::AbstractVector) = (norm(v) - 1.0) < 1.0e-12
9797
isprobvec(p::AbstractVector{<:Real}) =
9898
all(x -> x zero(x), p) && isapprox(sum(p), one(eltype(p)))
9999

100+
sqrt!!(x::AbstractVector{<:Real}) = map(sqrt, x)
101+
function sqrt!!(x::Vector{<:Real})
102+
for i in eachindex(x)
103+
x[i] = sqrt(x[i])
104+
end
105+
return x
106+
end
107+
100108
# get a type wide enough to represent all a distributions's parameters
101109
# (if the distribution is parametric)
102110
# if the distribution is not parametric, we need this to be a float so that

test/multivariate/dirichlet.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ rng = MersenneTwister(123)
2626
@test mode(d) fill(1/3, 3)
2727
@test cov(d) [8 -4 -4; -4 8 -4; -4 -4 8] / (36 * 7)
2828
@test var(d) diag(cov(d))
29+
@test std(d) sqrt.(var(d))
2930

3031
r = Vector{Float64}(undef, 3)
3132
Distributions.dirichlet_mode!(r, d.alpha, d.alpha0)
@@ -65,6 +66,7 @@ rng = MersenneTwister(123)
6566
@test mean(d) v / sum(v)
6667
@test cov(d) [8 -2 -6; -2 5 -3; -6 -3 9] / (36 * 7)
6768
@test var(d) diag(cov(d))
69+
@test std(d) sqrt.(var(d))
6870

6971
@test pdf(d, [0.2, 0.3, 0.5]) 3
7072
@test pdf(d, [0.4, 0.5, 0.1]) 0.24

test/multivariate/dirichletmultinomial.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,13 @@ d = DirichletMultinomial(10, α)
3333
@test mean(d) α * (d.n / d.α0)
3434
p = d.α / d.α0
3535
@test var(d) d.n * (d.n + d.α0) / (1 + d.α0) .* p .* (1.0 .- p)
36+
@test std(d) sqrt.(var(d))
3637
x = func[2](d, 10_000)
3738

3839
# test statistics with mle fit
3940
d = fit(DirichletMultinomial, x)
4041
@test isapprox(mean(d), vec(mean(x, dims=2)), atol=.5)
42+
@test isapprox(std(d) , vec(std(x, dims=2)) , atol=.5)
4143
@test isapprox(var(d) , vec(var(x, dims=2)) , atol=.5)
4244
@test isapprox(cov(d) , cov(x, dims=2) , atol=.5)
4345

test/multivariate/multinomial.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ using Test
1919
@test length(d) == 3
2020
@test d.n == nt
2121
@test mean(d) T[2., 5., 3.]
22+
@test std(d) T[sqrt(1.6), sqrt(2.5), sqrt(2.1)]
2223
@test var(d) T[1.6, 2.5, 2.1]
2324
@test cov(d) T[1.6 -1.0 -0.6; -1.0 2.5 -1.5; -0.6 -1.5 2.1]
2425

test/multivariate/mvlognormal.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ function test_mvlognormal(g::MvLogNormal, n_tsamples::Int=10^6,
3232
@test length(s) == d
3333
@test size(S) == (d, d)
3434
@test s diag(S)
35+
@test std(g) sqrt.(diag(S))
3536
@test md exp.(mean(g.normal))
3637
@test mn exp.(mean(g.normal) .+ var(g.normal)/2)
3738
@test mo exp.(mean(g.normal) .- var(g.normal))

test/multivariate/product.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ using Distributions: Product
2323
@test eltype(d_product) === eltype(ds[1])
2424
@test @inferred(logpdf(d_product, x)) sum(logpdf.(ds, x))
2525
@test mean(d_product) == mean.(ds)
26+
@test std(d_product) == std.(ds)
2627
@test var(d_product) == var.(ds)
2728
@test cov(d_product) == Diagonal(var.(ds))
2829
@test entropy(d_product) sum(entropy.(ds))
@@ -46,6 +47,7 @@ end
4647
@test eltype(d_product) === eltype(ds[1])
4748
@test @inferred(logpdf(d_product, x)) sum(logpdf.(ds, x))
4849
@test mean(d_product) == mean.(ds)
50+
@test std(d_product) == std.(ds)
4951
@test var(d_product) == var.(ds)
5052
@test cov(d_product) == Diagonal(var.(ds))
5153
@test entropy(d_product) == sum(entropy.(ds))
@@ -76,6 +78,7 @@ end
7678
@test eltype(d_product) === eltype(ds[1])
7779
@test @inferred(logpdf(d_product, x)) sum(logpdf.(ds, x))
7880
@test mean(d_product) == mean.(ds)
81+
@test std(d_product) == std.(ds)
7982
@test var(d_product) == var.(ds)
8083
@test cov(d_product) == Diagonal(var.(ds))
8184
@test entropy(d_product) == sum(entropy.(ds))

test/product.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ using LinearAlgebra
2525
@test length(d_product) == length(ds)
2626
@test eltype(d_product) === eltype(ds[1])
2727
@test mean(d_product) == mean.(ds)
28+
@test std(d_product) == std.(ds)
2829
@test var(d_product) == var.(ds)
2930
@test cov(d_product) == Diagonal(var.(ds))
3031
@test entropy(d_product) sum(entropy.(ds))
@@ -65,6 +66,7 @@ end
6566
@test length(d_product) == length(ds)
6667
@test eltype(d_product) === eltype(ds[1])
6768
@test @inferred(mean(d_product)) == mean.(ds)
69+
@test @inferred(std(d_product)) == std.(ds)
6870
@test @inferred(var(d_product)) == var.(ds)
6971
@test @inferred(cov(d_product)) == Diagonal(var.(ds))
7072
@test @inferred(entropy(d_product)) == sum(entropy.(ds))
@@ -115,6 +117,7 @@ end
115117
@test length(d_product) == length(ds)
116118
@test eltype(d_product) === eltype(ds[1])
117119
@test @inferred(mean(d_product)) == mean.(ds)
120+
@test @inferred(std(d_product)) == std.(ds)
118121
@test @inferred(var(d_product)) == var.(ds)
119122
@test @inferred(cov(d_product)) == Diagonal(var.(ds))
120123
@test @inferred(entropy(d_product)) == sum(entropy.(ds))
@@ -150,6 +153,7 @@ end
150153
@test length(d_product) == 3
151154
@test eltype(d_product) === Float64
152155
@test @inferred(mean(d_product)) == mean.(ds_vec)
156+
@test @inferred(std(d_product)) == std.(ds_vec)
153157
@test @inferred(var(d_product)) == var.(ds_vec)
154158
@test @inferred(cov(d_product)) == Diagonal(var.(ds_vec))
155159
@test @inferred(entropy(d_product)) == sum(entropy.(ds_vec))

0 commit comments

Comments
 (0)