Skip to content

Commit 6dc87a0

Browse files
committed
Major refactoring, breaking changes
1 parent 9f4fe67 commit 6dc87a0

File tree

6 files changed

+256
-153
lines changed

6 files changed

+256
-153
lines changed

src/EmpiricalDistributions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ using Distributions
1313
using StatsBase
1414

1515

16+
include("hist_funcs.jl")
1617
include("uv_binned_dist.jl")
1718
include("mv_binned_dist.jl")
1819

src/hist_funcs.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# This file is a part of EmpiricalDistributions.jl, licensed under the MIT License (MIT).
2+
3+
4+
"""
    _pdf(h::Histogram{T,N}, xs::NTuple{N,Real})

Return the weight of the bin of `h` that `StatsBase.binindex` selects for the
point `xs`, or `zero(T)` if that bin index lies outside `h.weights`.

`h` must be a density histogram (`h.isdensity == true`).
"""
function _pdf(h::Histogram{T,N}, xs::NTuple{N,Real}) where {T,N}
    @assert h.isdensity # Implementation requires normalized histogram

    bin = StatsBase.binindex(h, xs)
    weights = h.weights
    if !checkbounds(Bool, weights, bin...)
        # The point is not covered by any bin of the histogram.
        return zero(T)
    end
    density::T = @inbounds weights[bin...]
    return density
end
14+
15+
16+
"""
    _mean(h::StatsBase.Histogram{<:Real,N}; T::DataType = Float64)

Estimate the per-dimension mean of the data binned in `h`.

Every bin contributes the midpoint of its edges (per dimension), weighted by
the bin weight divided by the sum of all weights. Returns a `Vector{T}` of
length `N`. `h` must not be a density histogram.
"""
function _mean(h::StatsBase.Histogram{<:Real, N}; T::DataType = Float64) where {N}
    @assert !h.isdensity # Implementation currently assumes non-normalized histogram

    weights = h.weights
    inv_total::T = inv(sum(weights))
    centers = StatsBase.midpoints.(h.edges)
    result::Vector{T} = zeros(T, N)
    for cell in CartesianIndices(weights)
        w = weights[cell]
        for dim in 1:N
            result[dim] += inv_total * centers[dim][cell[dim]] * w
        end
    end
    return result
end
30+
31+
32+
# Index of the largest element of `A`: the plain index for vectors, a tuple of
# per-dimension indices for all other arrays.
_findmaxidx_tuple_or_int(A::AbstractVector{<:Real}) = argmax(A)
_findmaxidx_tuple_or_int(A::AbstractArray{<:Real}) = Tuple(argmax(A))
34+
35+
"""
    _mode(h::StatsBase.Histogram; T::DataType = Float64)

Estimate the mode of `h` as the center of the bin with the largest weight.

The center is the average of the bin's lower and upper edge in every
dimension; the result is returned as a `Vector{T}`. `h` must be a density
histogram (`h.isdensity == true`).
"""
function _mode(h::StatsBase.Histogram; T::DataType = Float64)
    @assert h.isdensity # Implementation requires normalized histogram

    peak = _findmaxidx_tuple_or_int(h.weights)
    # Edges enclosing the peak bin, one entry per dimension.
    lower = map(getindex, h.edges, peak)
    upper = map(getindex, h.edges, peak .+ 1)
    centers = (lower .+ upper) ./ 2
    return T[centers...]
end
43+
44+
45+
"""
    _var(h::StatsBase.Histogram{<:Real,N}; T::DataType = Float64, mean = _mean(h, T = T))

Estimate the per-dimension variance of the data binned in `h`.

Every bin contributes the squared deviation of its midpoint from `mean`
(per dimension), weighted by the bin weight divided by the sum of all
weights. Returns a `Vector{T}` of length `N`.

# Keywords
- `T`: element type of the result.
- `mean`: per-dimension mean, indexable by dimension. Defaults to the
  histogram mean computed by `_mean(h, T = T)`.

`h` must not be a density histogram.
"""
function _var(h::StatsBase.Histogram{<:Real, N}; T::DataType = Float64, mean = _mean(h, T = T)) where {N}
    @assert !h.isdensity # Implementation currently assumes non-normalized histogram

    s_inv::T = inv(sum(h.weights))
    v::Vector{T} = zeros(T, N)
    mps = StatsBase.midpoints.(h.edges)
    for i in CartesianIndices(h.weights)
        w = h.weights[i]
        for idim in 1:N
            v[idim] += s_inv * (mps[idim][i[idim]] - mean[idim])^2 * w
        end
    end
    return v
end
59+
60+
61+
"""
    _cov(h::StatsBase.Histogram{<:Real,N}; T::DataType = Float64, mean = _mean(h, T = T))

Estimate the covariance matrix of the data binned in `h`.

Every bin contributes the product of its midpoint deviations from `mean` in
dimensions `idim` and `jdim`, weighted by the bin weight divided by the sum of
all weights. Returns an `N`×`N` `Matrix{T}`.

# Keywords
- `T`: element type of the result.
- `mean`: per-dimension mean, indexable by dimension. Defaults to the
  histogram mean computed by `_mean(h, T = T)`.

`h` must not be a density histogram.
"""
function _cov(h::StatsBase.Histogram{<:Real, N}; T::DataType = Float64, mean = _mean(h, T = T)) where {N}
    @assert !h.isdensity # Implementation currently assumes non-normalized histogram

    s_inv::T = inv(sum(h.weights))
    c::Matrix{T} = zeros(T, N, N)
    mps = StatsBase.midpoints.(h.edges)
    for i in CartesianIndices(h.weights)
        w = h.weights[i]
        for idim in 1:N
            di = mps[idim][i[idim]] - mean[idim]
            for jdim in 1:N
                dj = mps[jdim][i[jdim]] - mean[jdim]
                c[idim, jdim] += s_inv * di * dj * w
            end
        end
    end
    return c
end

src/mv_binned_dist.jl

Lines changed: 51 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
"""
5-
UvBinnedDist <: Distribution{Univariate,Continuous}
5+
MvBinnedDist <: Distribution{Multivariate,Continuous}
66
77
Wraps a multi-dimensional histogram and presents it as a binned multivariate
88
distribution.
@@ -11,16 +11,21 @@ Constructor:
1111
1212
MvBinnedDist(h::Histogram{<:Real,N})
1313
"""
14-
struct MvBinnedDist{T, N} <: Distributions.Distribution{Multivariate,Continuous}
15-
h::StatsBase.Histogram{<:Real, N}
16-
edges::NTuple{N, <:AbstractVector{T}}
17-
cart_inds::CartesianIndices{N, NTuple{N, Base.OneTo{Int}}}
18-
19-
probabilty_edges::AbstractVector{T}
20-
21-
μ::AbstractVector{T}
22-
var::AbstractVector{T}
23-
cov::AbstractMatrix{T}
14+
struct MvBinnedDist{
    T<:Real, N, H<:Histogram{<:Real,N}, VT<:AbstractVector{T}, MT<:AbstractMatrix{T}
} <: Distributions.Distribution{Multivariate,Continuous}
    # Wrapped histogram; returned by `convert(Histogram, d)` and evaluated by `pdf`.
    hist::H
    # Bin edges, one vector per dimension.
    # NOTE(review): this field type is not concrete (`<:AbstractVector{T}`);
    # consider a dedicated type parameter if access to it shows up in profiles.
    _edges::NTuple{N, <:AbstractVector{T}}
    # Maps a linear cell index to its per-dimension bin indices (used by `_rand!`).
    _cart_inds::CartesianIndices{N, NTuple{N, Base.OneTo{Int}}}
    # Sorted values searched by `_rand!` to pick the cell of a uniform draw.
    _probability_edges::VT
    # Summary statistics, returned by `mean`, `mode`, `var` and `cov`.
    _mean::VT
    _mode::VT
    _var::VT
    _cov::MT
end
2530

2631
export MvBinnedDist
@@ -37,83 +42,45 @@ function MvBinnedDist(h::StatsBase.Histogram{<:Real, N}, T::DataType = Float64)
3742
probabilty_edges[i+1] = v > 1 ? 1 : v
3843
end
3944

40-
mean = _mean(h)
41-
var = _var(h, mean = mean)
42-
cov = _cov(h, mean = mean)
45+
mean_est = _mean(h)
46+
mode_est = _mode(nh)
47+
var_est = _var(h, mean = mean_est)
48+
cov_est = _cov(h, mean = mean_est)
4349

44-
return MvBinnedDist{T, N}(
50+
return MvBinnedDist(
4551
nh,
4652
collect.(nh.edges),
4753
CartesianIndices(nh.weights),
4854
probabilty_edges,
49-
mean,
50-
var,
51-
cov
55+
mean_est,
56+
mode_est,
57+
var_est,
58+
cov_est
5259
)
5360
end
5461

5562

63+
# Conversion back to a histogram hands out the wrapped `hist` field itself (no copy).
Base.convert(::Type{Histogram}, d::MvBinnedDist) = d.hist

# A variate has one entry per histogram dimension `N`; its element type is `T`.
Base.length(::MvBinnedDist{T, N}) where {T, N} = N
Base.size(::MvBinnedDist{T, N}) where {T, N} = (N,)
Base.eltype(::MvBinnedDist{T, N}) where {T, N} = T
5969

60-
Statistics.mean(d::MvBinnedDist{T, N}) where {T, N} = d.μ
61-
Statistics.var(d::MvBinnedDist{T, N}) where {T, N} = d.var
62-
Statistics.cov(d::MvBinnedDist{T, N}) where {T, N} = d.cov
63-
64-
65-
function _mean(h::StatsBase.Histogram{<:Real, N}; T::DataType = Float64) where {N}
66-
s_inv::T = inv(sum(h.weights))
67-
m::Vector{T} = zeros(T, N)
68-
mps = StatsBase.midpoints.(h.edges)
69-
cart_inds = CartesianIndices(h.weights)
70-
for i in cart_inds
71-
for idim in 1:N
72-
m[idim] += s_inv * mps[idim][i[idim]] * h.weights[i]
73-
end
74-
end
75-
return m
76-
end
77-
78-
79-
function _var(h::StatsBase.Histogram{<:Real, N}; T::DataType = Float64, mean = StatsBase.mean(h, T = T), ) where {N}
80-
s_inv::T = inv(sum(h.weights))
81-
v::Vector{T} = zeros(T, N)
82-
mps = StatsBase.midpoints.(h.edges)
83-
cart_inds = CartesianIndices(h.weights)
84-
for i in cart_inds
85-
for idim in 1:N
86-
v[idim] += s_inv * (mps[idim][i[idim]] - mean[idim])^2 * h.weights[i]
87-
end
88-
end
89-
return v
90-
end
91-
92-
93-
function _cov(h::StatsBase.Histogram{<:Real, N}; T::DataType = Float64, mean = StatsBase.mean(h, T = T)) where {N}
94-
s_inv::T = inv(sum(h.weights))
95-
c::Matrix{T} = zeros(T, N, N)
96-
mps = StatsBase.midpoints.(h.edges)
97-
cart_inds = CartesianIndices(h.weights)
98-
for i in cart_inds
99-
for idim in 1:N
100-
for jdim in 1:N
101-
c[idim, jdim] += s_inv * (mps[idim][i[idim]] - mean[idim]) * (mps[jdim][i[jdim]] - mean[jdim]) * h.weights[i]
102-
end
103-
end
104-
end
105-
return c
106-
end
70+
# Summary statistics are read from the fields stored in the distribution.
Statistics.mean(d::MvBinnedDist) = d._mean
StatsBase.mode(d::MvBinnedDist) = d._mode
Statistics.var(d::MvBinnedDist) = d._var
Statistics.cov(d::MvBinnedDist) = d._cov
10774

10875

10976
function Distributions._rand!(r::AbstractRNG, d::MvBinnedDist{T,N}, A::AbstractVector{<:Real}) where {T, N}
11077
rand!(r, A)
111-
next_inds::UnitRange{Int} = searchsorted(d.probabilty_edges::Vector{T}, A[1]::T)
78+
next_inds::UnitRange{Int} = searchsorted(d._probability_edges::Vector{T}, A[1]::T)
11279
cell_lin_index::Int = min(next_inds.start, next_inds.stop)
113-
cell_car_index = d.cart_inds[cell_lin_index]
80+
cell_car_index = d._cart_inds[cell_lin_index]
11481
for idim in Base.OneTo(N)
11582
i = cell_car_index[idim]
116-
sub_int = d.edges[idim][i:i+1]
83+
sub_int = d._edges[idim][i:i+1]
11784
sub_int_width::T = sub_int[2] - sub_int[1]
11885
A[idim] = sub_int[1] + sub_int_width * A[idim]
11986
end
@@ -122,11 +89,25 @@ end
12289

12390
function Distributions._rand!(r::AbstractRNG, d::MvBinnedDist{T,N}, A::AbstractMatrix{<:Real}) where {T, N}
    # Fill every element of `nestedview(A)` with one draw, in order, using the
    # vector method of `_rand!`.
    for variate in nestedview(A)
        Distributions._rand!(r, d, variate)
    end
    return A
end
94+
95+
96+
# Similar to unroll_tuple in StaticArrays.jl:
97+
@generated function _unsafe_unroll_tuple(A::AbstractArray, ::Val{L}) where {L}
    # Build the expression `(A[idx0 + 0], ..., A[idx0 + L - 1])` once per `L`.
    # Elements are read with `@inbounds`, so the caller must guarantee that `A`
    # holds at least `L` elements.
    elems = Expr(:tuple)
    for offset in 0:(L-1)
        push!(elems.args, :(A[idx0 + $offset]))
    end
    return quote
        idx0 = firstindex(A)
        Base.@_inline_meta
        @inbounds return $elems
    end
end
126105

127106

128-
function Distributions.pdf(d::MvBinnedDist{T, N}, x::AbstractArray{<:Real, 1}) where {T, N}
129-
return @inbounds d.h.weights[StatsBase.binindex(d.h, Tuple(x))...]
107+
function Distributions.pdf(d::MvBinnedDist{T,N}, x::AbstractVector{<:Real}) where {T,N}
    # The variate must have exactly one entry per dimension of the distribution.
    if length(eachindex(x)) != N
        throw(ArgumentError("Length of variate doesn't match dimensionality of distribution"))
    end
    # `N` is a type parameter, so `x` can be unpacked into an `N`-tuple for `_pdf`.
    point = _unsafe_unroll_tuple(x, Val(N))
    return _pdf(d.hist, point)
end
131112

132113

0 commit comments

Comments
 (0)