Skip to content

Commit 149a279

Browse files
committed
Add UvBinnedDist and MvBinnedDist
Based on original implementation by Lukas Hauertmann <lhauert@mpp.mpg.de> in BAT.jl.
1 parent d27aa9a commit 149a279

File tree

7 files changed

+313
-1
lines changed

7 files changed

+313
-1
lines changed

LICENSE.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ The EmpiricalDistributions.jl package is licensed under the MIT "Expat" License:
22

33
> Copyright (c) 2018:
44
>
5+
> Lukas Hauertmann <lhauert@mpp.mpg.de>,
56
> Oliver Schulz <oschulz@mpp.mpg.de>
67
>
78
> Permission is hereby granted, free of charge, to any person obtaining a copy

src/EmpiricalDistributions.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,7 @@ using Distributions
1313
using StatsBase
1414

1515

16+
include("uv_binned_dist.jl")
17+
include("mv_binned_dist.jl")
18+
1619
end # module

src/mv_binned_dist.jl

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# This file is a part of EmpiricalDistributions.jl, licensed under the MIT License (MIT).
2+
3+
"""
    MvBinnedDist{T, N} <: Distribution{Multivariate,Continuous}

`N`-dimensional continuous distribution backed by a `StatsBase.Histogram`.

# Fields
- `h`: the histogram the distribution is built from (the constructor stores
  the normalized histogram here).
- `edges`: per-dimension bin edges, materialized as vectors.
- `cart_inds`: Cartesian indices of the weights array; maps a linear cell
  index back to per-dimension bin indices.
- `probabilty_edges`: cumulative cell probabilities (length `ncells + 1`),
  used to select a cell when sampling.
- `μ`, `var`, `cov`: stored mean vector, per-dimension variances and
  covariance matrix.

The numeric fields are concretely typed so that field access is type-stable;
values passed to the constructor are converted as needed.
"""
struct MvBinnedDist{T, N} <: Distributions.Distribution{Multivariate,Continuous}
    # NOTE: still abstractly typed in the weight element type; making it
    # concrete would require an additional type parameter.
    h::StatsBase.Histogram{<:Real, N}
    edges::NTuple{N, Vector{T}}
    cart_inds::CartesianIndices{N, NTuple{N, Base.OneTo{Int}}}

    probabilty_edges::Vector{T}

    μ::Vector{T}
    var::Vector{T}
    cov::Matrix{T}
end
14+
15+
export MvBinnedDist
16+
17+
18+
"""
    MvBinnedDist(h::StatsBase.Histogram{<:Real, N}, T::DataType = Float64)

Build an `MvBinnedDist{T, N}` from the histogram `h`.

The histogram is normalized with `normalize(h)` (presumably StatsBase's
default normalization mode — confirm), the cumulative cell probabilities are
tabulated for sampling, and mean, variances and covariance are computed from
the bin midpoints using element type `T`.
"""
function MvBinnedDist(h::StatsBase.Histogram{<:Real, N}, T::DataType = Float64) where {N}
    nh = normalize(h)

    # Probability of each cell, visited in the linear order of the weights array.
    cell_probs = nh.weights * inv(sum(nh.weights))

    # Allocate with the requested element type (not a hard-coded Float64).
    probabilty_edges = Vector{T}(undef, length(cell_probs) + 1)
    probabilty_edges[1] = zero(T)
    for (i, p) in enumerate(cell_probs)
        probabilty_edges[i + 1] = min(probabilty_edges[i] + p, one(T))
    end
    # Pin the last edge to exactly 1 so that accumulated rounding error can
    # never leave a gap at the top that a uniform draw in [0, 1) could fall into.
    probabilty_edges[end] = one(T)

    # Forward T so the moments are computed in the requested precision.
    mean_vec = _mean(h, T = T)
    var_vec = _var(h, T = T, mean = mean_vec)
    cov_mat = _cov(h, T = T, mean = mean_vec)

    return MvBinnedDist{T, N}(
        nh,
        collect.(nh.edges),
        CartesianIndices(nh.weights),
        probabilty_edges,
        mean_vec,
        var_vec,
        cov_mat
    )
end
43+
44+
45+
# Dimensionality and element type are read straight off the type parameters.
Base.length(::MvBinnedDist{T, N}) where {T, N} = N
Base.size(::MvBinnedDist{T, N}) where {T, N} = (N,)
Base.eltype(::MvBinnedDist{T, N}) where {T, N} = T

# The moments are stored as fields of the distribution; these just return them.
Statistics.mean(d::MvBinnedDist) = d.μ
Statistics.var(d::MvBinnedDist) = d.var
Statistics.cov(d::MvBinnedDist) = d.cov
52+
53+
54+
# Mean of an N-dimensional histogram: every cell contributes its midpoint
# (per dimension), weighted by the cell weight divided by the total weight.
# `T` selects the element type of the accumulator and of the result.
function _mean(h::StatsBase.Histogram{<:Real, N}; T::DataType = Float64) where {N}
    inv_total::T = inv(sum(h.weights))
    centers = StatsBase.midpoints.(h.edges)
    acc::Vector{T} = zeros(T, N)
    for cell in CartesianIndices(h.weights), dim in 1:N
        acc[dim] += inv_total * centers[dim][cell[dim]] * h.weights[cell]
    end
    return acc
end
66+
67+
68+
"""
    _var(h::StatsBase.Histogram{<:Real, N}; T = Float64, mean = _mean(h, T = T))

Per-dimension variance of the histogram `h`, computed from the bin midpoints
weighted by the cell weights (normalized by the total weight).

# Keywords
- `T`: element type of the result.
- `mean`: precomputed mean vector; defaults to `_mean(h, T = T)`. (The default
  previously called `StatsBase.mean(h, T = T)`, which is not a defined method
  for histograms and failed whenever `mean` was omitted.)
"""
function _var(h::StatsBase.Histogram{<:Real, N}; T::DataType = Float64, mean = _mean(h, T = T)) where {N}
    s_inv::T = inv(sum(h.weights))
    v::Vector{T} = zeros(T, N)
    mps = StatsBase.midpoints.(h.edges)
    for i in CartesianIndices(h.weights)
        for idim in 1:N
            v[idim] += s_inv * (mps[idim][i[idim]] - mean[idim])^2 * h.weights[i]
        end
    end
    return v
end
80+
81+
82+
"""
    _cov(h::StatsBase.Histogram{<:Real, N}; T = Float64, mean = _mean(h, T = T))

`N × N` covariance matrix of the histogram `h`, computed from the bin
midpoints weighted by the cell weights (normalized by the total weight).

# Keywords
- `T`: element type of the result.
- `mean`: precomputed mean vector; defaults to `_mean(h, T = T)`. (The default
  previously called `StatsBase.mean(h, T = T)`, which is not a defined method
  for histograms and failed whenever `mean` was omitted.)
"""
function _cov(h::StatsBase.Histogram{<:Real, N}; T::DataType = Float64, mean = _mean(h, T = T)) where {N}
    s_inv::T = inv(sum(h.weights))
    c::Matrix{T} = zeros(T, N, N)
    mps = StatsBase.midpoints.(h.edges)
    for i in CartesianIndices(h.weights)
        for idim in 1:N
            for jdim in 1:N
                c[idim, jdim] += s_inv * (mps[idim][i[idim]] - mean[idim]) * (mps[jdim][i[jdim]] - mean[jdim]) * h.weights[i]
            end
        end
    end
    return c
end
96+
97+
98+
"""
    Distributions._rand!(r::AbstractRNG, d::MvBinnedDist, A::AbstractVector{<:Real})

Overwrite `A` with one sample from `d` and return `A`.

A cell is selected by locating a uniform draw in the cumulative cell
probabilities; the point is then placed uniformly inside that cell,
independently in every dimension.
"""
function Distributions._rand!(r::AbstractRNG, d::MvBinnedDist{T,N}, A::AbstractVector{<:Real}) where {T, N}
    # Per-dimension uniform offsets inside the (yet to be chosen) cell.
    rand!(r, A)

    # Use a separate draw for the cell choice. Reusing A[1] for both the cell
    # choice and the offset in dimension 1 couples the two, so the dimension-1
    # position inside a cell would not be uniform.
    u = rand(r)

    # Last edge <= u identifies the cell; on an exact hit of a repeated edge
    # this skips zero-probability cells. Clamp defensively against rounding.
    cell_lin_index = clamp(searchsortedlast(d.probabilty_edges, u), 1, length(d.cart_inds))
    cell_car_index = d.cart_inds[cell_lin_index]

    for idim in Base.OneTo(N)
        i = cell_car_index[idim]
        lo = d.edges[idim][i]
        hi = d.edges[idim][i + 1]
        A[idim] = lo + (hi - lo) * A[idim]
    end
    return A
end
111+
112+
"""
    Distributions._rand!(r::AbstractRNG, d::MvBinnedDist, A::AbstractMatrix{<:Real})

Fill every column of `A` with an independent sample from `d` and return `A`.
"""
function Distributions._rand!(r::AbstractRNG, d::MvBinnedDist{T,N}, A::AbstractMatrix{<:Real}) where {T, N}
    # Column views let the vector method write in place, using only Base.
    for j in axes(A, 2)
        Distributions._rand!(r, d, view(A, :, j))
    end
    return A
end
115+
116+
117+
"""
    Distributions.pdf(d::MvBinnedDist, x::AbstractVector{<:Real})

Return the weight of the histogram cell containing `x`, or zero when `x`
lies outside the binned region.
"""
function Distributions.pdf(d::MvBinnedDist{T, N}, x::AbstractArray{<:Real, 1}) where {T, N}
    w = d.h.weights
    idx = StatsBase.binindex(d.h, Tuple(x))
    # Points outside the histogram map to out-of-range bin indices; reading
    # them under @inbounds would be undefined behaviour.
    checkbounds(Bool, w, idx...) || return zero(eltype(w))
    return @inbounds w[idx...]
end
120+
121+
122+
# Log-density: the logarithm of `pdf` at the same point.
Distributions.logpdf(d::MvBinnedDist, x::AbstractVector{<:Real}) = log(pdf(d, x))

# Internal Distributions hook; forwards to the `logpdf` method for MvBinnedDist.
Distributions._logpdf(d::MvBinnedDist, x::AbstractVector{<:Real}) = logpdf(d, x)

src/uv_binned_dist.jl

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# This file is a part of EmpiricalDistributions.jl, licensed under the MIT License (MIT).
2+
3+
4+
"""
    UvBinnedDist{T <: AbstractFloat} <: Distribution{Univariate,Continuous}

Univariate continuous distribution backed by a one-dimensional histogram.
All derived tables are filled in by the outer constructor `UvBinnedDist(h, T)`.
"""
struct UvBinnedDist{T <: AbstractFloat} <: Distribution{Univariate,Continuous}
    # Histogram and per-bin geometry.
    h::Histogram{<:Real, 1}       # histogram the distribution is based on
    inv_weights::Vector{T}        # elementwise inverse of the histogram weights
    edges::Vector{T}              # bin edges
    volumes::Vector{T}            # bin widths (differences of consecutive edges)

    # Tables over the unit interval, used for sampling.
    probabilty_edges::Vector{T}          # cumulative bin probabilities
    probabilty_volumes::Vector{T}        # probability of each bin
    probabilty_inv_volumes::Vector{T}    # elementwise inverse of the above

    # Accumulated probability at the left edge of each bin.
    acc_prob::Vector{T}

    # Stored moments.
    μ::T
    var::T
    cov::Matrix{T}    # 1×1 matrix holding `var`
    σ::T
end
21+
22+
export UvBinnedDist
23+
24+
25+
"""
    UvBinnedDist(h::Histogram{<:Real, 1}, T::DataType = Float64)

Build a `UvBinnedDist{T}` from the one-dimensional histogram `h`.

The histogram is normalized with `normalize(h)`; the cumulative bin
probabilities, the accumulated probability per bin and the first two moments
(taken over the bin midpoints, weighted by the bin probabilities) are
tabulated with element type `T`.
"""
function UvBinnedDist(h::Histogram{<:Real, 1}, T::DataType = Float64)
    nh = normalize(h)

    # Probability mass of each bin.
    probabilty_widths::Vector{T} = h.weights * inv(sum(h.weights))

    # Allocate with the requested element type (not a hard-coded Float64).
    probabilty_edges = Vector{T}(undef, length(probabilty_widths) + 1)
    probabilty_edges[1] = zero(T)
    for (i, w) in enumerate(probabilty_widths)
        probabilty_edges[i + 1] = probabilty_edges[i] + w
    end
    # Pin the last edge to exactly 1 to absorb accumulated rounding error.
    probabilty_edges[end] = one(T)

    volumes = diff(h.edges[1])

    # Weight the midpoints by bin *probability*. Weighting by the normalized
    # histogram weights gives the same result only when all bins have equal
    # width, since those weights also depend on the bin width.
    centers = StatsBase.midpoints(nh.edges[1])
    bin_weights = ProbabilityWeights(probabilty_widths)
    mean_val = Statistics.mean(centers, bin_weights)
    var_val = Statistics.var(centers, bin_weights, mean = mean_val)

    # acc_prob[i] is the probability accumulated strictly below bin i.
    acc_prob = zeros(T, length(nh.weights))
    for i in 2:length(acc_prob)
        acc_prob[i] = acc_prob[i - 1] + nh.weights[i - 1] * volumes[i - 1]
    end

    return UvBinnedDist{T}(
        nh,
        inv.(nh.weights),
        nh.edges[1],
        volumes,
        probabilty_edges,
        probabilty_widths,
        inv.(probabilty_widths),
        acc_prob,
        mean_val,
        var_val,
        fill(var_val, 1, 1),
        sqrt(var_val)
    )
end
58+
59+
60+
"""
    Random.rand(rng::AbstractRNG, d::UvBinnedDist{T}) -> T

Draw one sample from `d` using `rng`.

A uniform number selects a bin through the cumulative bin probabilities and
is then mapped linearly onto that bin's interval.
"""
function Random.rand(rng::AbstractRNG, d::UvBinnedDist{T})::T where {T <: AbstractFloat}
    # Draw from the supplied generator (not the global one) so that seeding
    # `rng` makes sampling reproducible.
    r::T = rand(rng, T)

    # Last cumulative edge <= r identifies the bin. On an exact hit of a
    # repeated edge this skips zero-probability bins, whose infinite inverse
    # probability would otherwise produce 0 * Inf = NaN below.
    k = clamp(searchsortedlast(d.probabilty_edges, r), 1, length(d.volumes))

    inv_p = d.probabilty_inv_volumes[k]
    # Defensive: a zero-probability bin can only be reached through clamping.
    isfinite(inv_p) || return d.edges[k]

    # Fraction of the bin's probability lying above r, mapped onto its width.
    return d.edges[k] + d.volumes[k] * (d.probabilty_edges[k + 1] - r) * inv_p
end
75+
76+
77+
"""
    Distributions.pdf(d::UvBinnedDist{T}, x::Real) -> T

Return the histogram weight of the bin containing `x`, or zero when `x`
does not fall into any bin.
"""
function Distributions.pdf(d::UvBinnedDist{T}, x::Real)::T where {T <: AbstractFloat}
    w = d.h.weights
    i::Int = StatsBase.binindex(d.h, x)
    # Outside the histogram the bin index is out of range; reading it under
    # @inbounds would be undefined behaviour.
    return checkbounds(Bool, w, i) ? (@inbounds w[i]) : zero(T)
end
81+
82+
83+
# Log-density: the logarithm of `pdf(d, x)`, converted to `T`.
function Distributions.logpdf(d::UvBinnedDist{T}, x::Real)::T where {T <: AbstractFloat}
    density = pdf(d, x)
    return log(density)
end
86+
87+
88+
"""
    Distributions.cdf(d::UvBinnedDist{T}, x::Real) -> T

Cumulative probability at `x`: the probability accumulated in all bins left
of the one containing `x`, plus the part of that bin lying below `x`.
Returns 0 at or below the first edge and 1 at or above the last edge.
"""
function Distributions.cdf(d::UvBinnedDist{T}, x::Real)::T where {T <: AbstractFloat}
    isnan(x) && return T(NaN)
    # Outside the binned range the bin index would be out of bounds.
    x <= d.edges[1] && return zero(T)
    x >= d.edges[end] && return one(T)

    i::Int = StatsBase.binindex(d.h, x)
    # acc_prob[i] already holds the sum over bins 1:i-1 of weight * width,
    # so there is no need to recompute (and allocate) the prefix sum.
    return d.acc_prob[i] + (x - d.edges[i]) * d.h.weights[i]
end
94+
95+
96+
# Lower end of the support: the first bin edge.
function Distributions.minimum(d::UvBinnedDist{T})::T where {T <: AbstractFloat}
    return first(d.edges)
end


# Upper end of the support: the last bin edge.
function Distributions.maximum(d::UvBinnedDist{T})::T where {T <: AbstractFloat}
    return last(d.edges)
end


# `x` is in the support when it lies between the outermost edges (inclusive).
function Distributions.insupport(d::UvBinnedDist{T}, x::Real)::Bool where {T <: AbstractFloat}
    lo, hi = first(d.edges), last(d.edges)
    return lo <= x && x <= hi
end
108+
109+
110+
# Quantile lookup: find the bin whose accumulated probability brackets `x`,
# start at that bin's left edge and advance by the probability still missing
# times the inverse bin weight. Bins with an infinite inverse weight contribute
# no advance. The result is capped at the upper end of the support.
function Distributions.quantile(d::UvBinnedDist{T}, x::Real)::T where {T <: AbstractFloat}
    hits::UnitRange{Int} = searchsorted(d.acc_prob, x)
    # Empty range: take the entry just below x; otherwise the first exact match.
    k::Int = isempty(hits) ? last(hits) : first(hits)
    reached::T = d.acc_prob[k]
    left_edge::T = d.edges[k]
    remaining::T = x - reached
    slope::T = d.inv_weights[k]
    result::T = isinf(slope) ? left_edge : left_edge + remaining * slope
    return min(result, maximum(d))
end
122+
123+
124+
# Element type is the type parameter of the distribution.
Base.eltype(::UvBinnedDist{T}) where {T <: AbstractFloat} = T


# Stored moments, returned as-is from the struct fields.
function Statistics.mean(d::UvBinnedDist)
    return d.μ
end


function Statistics.var(d::UvBinnedDist)
    return d.var
end


function Statistics.cov(d::UvBinnedDist)
    return d.cov
end

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import Test
44
Test.@testset "Package EmpiricalDistributions" begin
55

6-
# test code goes here
6+
include("test_uv_binned_dist.jl")
7+
include("test_mv_binned_dist.jl")
78

89
end # testset

test/test_mv_binned_dist.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# This file is a part of EmpiricalDistributions.jl, licensed under the MIT License (MIT).
2+
3+
using EmpiricalDistributions
4+
using Test
5+
6+
using Random
7+
using Distributions, StatsBase, LinearAlgebra
8+
9+
10+
@testset "mv_binned_dist" begin
    Random.seed!(123)

    # Reference distribution: a correlated 2-D normal.
    μ = [1.23, -0.67]
    root = [0.45 0.32; 0.32 0.76]
    Σ = root' * root
    true_dist = MvNormal(μ, Σ)

    # Histogram edges per dimension, stepped by a tenth of the diagonal entries.
    edges_1 = μ[1]-10Σ[1]:Σ[1]/10:μ[1]+10Σ[1]
    edges_2 = μ[2]-10Σ[4]:Σ[4]/10:μ[2]+10Σ[4]
    h = Histogram((edges_1, edges_2))

    n = 10^6
    r = rand(true_dist, n)
    append!(h, (r[1, :], r[2, :]))

    # Moments stored in the binned distribution match the reference.
    d = MvBinnedDist(h)
    @test all(isapprox.(μ, d.μ, atol = 0.01))
    @test all(isapprox.(Σ, d.cov, atol = 0.01))

    # Samples drawn from the binned distribution reproduce the reference.
    rand!(d, r)
    fit_dist = fit(MvNormal, r)
    @test all(isapprox.(μ, fit_dist.μ, atol = 0.01))
    @test all(isapprox.(Σ, fit_dist.Σ.mat, atol = 0.01))
end

test/test_uv_binned_dist.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# This file is a part of EmpiricalDistributions.jl, licensed under the MIT License (MIT).
2+
3+
using EmpiricalDistributions
4+
using Test
5+
6+
using Random
7+
using Distributions, StatsBase
8+
9+
10+
@testset "uv_binned_dist" begin
    Random.seed!(123)
    μ, σ = 1.23, 0.74
    true_dist = Normal(μ, σ)

    # Edges span ±10σ around the mean in steps of σ/10. The original line was
    # garbled in extraction; this range is reconstructed to mirror the
    # multivariate test — confirm against the repository.
    h = Histogram(μ-10σ:σ/10:μ+10σ)
    append!(h, rand(true_dist, 10^7))

    # Moments stored in the binned distribution match the reference.
    d = UvBinnedDist(h)
    @test isapprox(μ, d.μ, atol = 0.01)
    @test isapprox(σ, d.σ, atol = 0.01)

    # Samples drawn from the binned distribution reproduce the reference.
    fit_dist = fit(Normal, rand(d, 10^7))
    @test isapprox(μ, fit_dist.μ, atol = 0.01)
    @test isapprox(σ, fit_dist.σ, atol = 0.01)
end

0 commit comments

Comments
 (0)