Skip to content

Commit 99d6b25

Browse files
committed
Substantial refactoring and cleanup, type parameters change
1 parent 51f2a53 commit 99d6b25

File tree

10 files changed

+374
-241
lines changed

10 files changed

+374
-241
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ uuid = "0bbb1fad-0f24-45fe-94a4-415852c5cc3b"
33
version = "0.2.1"
44

55
[deps]
6+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
67
ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018"
78
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -11,13 +12,15 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1112
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1213

1314
[compat]
15+
Adapt = "1.0, 2.0"
1416
ArraysOfArrays = "0.4, 0.5"
1517
Distributions = "0.21.3, 0.22, 0.23, 0.24"
1618
StatsBase = "0.32, 0.33"
1719
julia = "1"
1820

1921
[extras]
22+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2023
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2124

2225
[targets]
23-
test = ["Test"]
26+
test = ["ForwardDiff", "Test"]

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
[![Codecov](https://codecov.io/gh/oschulz/EmpiricalDistributions.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/oschulz/EmpiricalDistributions.jl)
88

99
A Julia package for empirical probability distributions. Currently
10-
implements uni- and multivariate binned distributions, backed by
11-
[StatsBase.jl](https://github.com/JuliaStats/StatsBase.jl) histograms.
10+
implements uni- and multivariate binned distributions that can be created
11+
from [StatsBase.jl](https://github.com/JuliaStats/StatsBase.jl) histograms.
1212

1313

1414
## Documentation

docs/src/index.md

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
A Julia package for empirical probability distributions that implement the
44
[Distributions.jl](https://github.com/JuliaStats/Distributions.jl) API.
55

6-
This package currently provides uni- and multivariate binned distributions,
7-
backed by [StatsBase.jl](https://github.com/JuliaStats/StatsBase.jl)
8-
histograms.
6+
This package currently provides uni- and multivariate binned distributions
7+
that can be created from
8+
[StatsBase.jl](https://github.com/JuliaStats/StatsBase.jl) histograms.
99

10-
[`UvBinnedDist`](@ref) wraps a 1-dimensional histogram and presents it as
11-
a (binned) univariate continuous distribution:
10+
[`UvBinnedDist`](@ref), usually derived from a 1-dimensional histogram,
11+
represents a binned univariate continuous distribution:
1212

1313
```julia
1414
using Distributions, StatsBase
@@ -31,15 +31,21 @@ maximum(uvdist), minimum(uvdist)
3131
rand(uvdist, 5)
3232
```
3333

34-
[`MvBinnedDist`](@ref) does the same for a multi-dimensional histogram,
35-
and presents it as a (binned) multivariate continuous distribution:
34+
A binned distribution can be converted back to a histogram:
35+
36+
```julia
37+
using LinearAlgebra
38+
Histogram(uvdist) == normalize(uvhist)
39+
```
40+
41+
42+
[`MvBinnedDist`](@ref), usually derived from a multidimensional histogram,
43+
represents a binned multivariate continuous distribution:
3644

3745
```julia
3846
X_mv = rand(MvNormal([3.5, 0.5], [2.0 0.5; 0.5 1.0]), 10^5)
3947
mvhist = fit(Histogram, (X_mv[1, :], X_mv[2, :]))
4048

41-
using Distributions, EmpiricalDistributions
42-
4349
mvdist = MvBinnedDist(mvhist)
4450
mvdist isa Distribution{Multivariate,Continuous}
4551

src/EmpiricalDistributions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@ using LinearAlgebra
88
using Random
99
using Statistics
1010

11+
using Adapt
1112
using ArraysOfArrays
1213
using Distributions
1314
using StatsBase
1415

1516

16-
include("hist_funcs.jl")
17+
include("utils.jl")
1718
include("uv_binned_dist.jl")
1819
include("mv_binned_dist.jl")
1920

src/hist_funcs.jl

Lines changed: 0 additions & 76 deletions
This file was deleted.

src/mv_binned_dist.jl

Lines changed: 86 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,38 @@
22

33

44
"""
5-
MvBinnedDist <: Distribution{Multivariate,Continuous}
5+
MvBinnedDist <: Distributions.Distribution{Multivariate,Continuous}
66
7-
Wraps a multi-dimensional histograms and presents it as a binned multivariate
8-
distribution.
7+
A binned multivariate distribution, usually derived from a histogram.
98
10-
Constructor:
9+
Constructors:
1110
12-
MvBinnedDist(h::Histogram{<:Real,N})
11+
```julia
12+
MvBinnedDist(h::StatsBase.Histogram{<:Real,N})
13+
MvBinnedDist{T<:Real}(h::StatsBase.Histogram{<:Real,N})
14+
```
15+
16+
You can convert a `MvBinnedDist` back to a histogram via
17+
18+
```julia
19+
convert(StatsBase.Histogram, dist::MvBinnedDist)
20+
StatsBase.Histogram(dist::MvBinnedDist)
21+
```
1322
"""
1423
struct MvBinnedDist{
1524
T <: Real,
1625
N,
17-
H <: Histogram{<:Real, N},
26+
U <: Real,
27+
ET <: NTuple{N,AbstractVector{<:Real}},
1828
VT <: AbstractVector{T},
19-
MT <: AbstractMatrix{T}
29+
MT <: AbstractMatrix{T},
30+
VU <: AbstractVector{U},
31+
AU <: AbstractArray{U,N}
2032
} <: Distributions.Distribution{Multivariate,Continuous}
21-
hist::H
22-
_edges::NTuple{N, <:AbstractVector{T}}
23-
_cart_inds::CartesianIndices{N, NTuple{N, Base.OneTo{Int}}}
24-
_probability_edges::VT
33+
_edges::ET
34+
_bin_pdf::AU
35+
_bin_linidx_cdf::VU
36+
_closed_left::Bool
2537
_mean::VT
2638
_mode::VT
2739
_var::VT
@@ -31,64 +43,75 @@ end
3143
export MvBinnedDist
3244

3345

34-
function MvBinnedDist(h::StatsBase.Histogram{<:Real, N}, T::DataType = Float64) where {N}
46+
# Build a binned multivariate distribution from histogram `h`.
#
# `T` is the element type used for the stored mean, mode, variance and
# covariance. The histogram is normalized first; its edges and weights are
# stored as-is and the weights are used as per-bin densities.
function MvBinnedDist{T}(h::Histogram{<:Real}) where {T<:Real}
    nh = normalize(h)

    edges = nh.edges
    bin_pdf = nh.weights
    closed_left = nh.closed == :left

    # `_bin_widths` is defined elsewhere; presumably one vector of bin widths
    # per axis -- TODO confirm.
    widths = _bin_widths.(edges)

    # Per-bin probability mass (density times the product of the bin's widths),
    # accumulated in linear-index order of the weights array.
    bin_linidx_cdf = cumsum(broadcast(
        idx -> bin_pdf[idx] * prod(map(getindex, widths, idx.I)),
        vec(CartesianIndices(bin_pdf)),
    ))
    # The comparison has to be approximate: an accumulated floating-point sum
    # will rarely be exactly 1.
    @assert last(bin_linidx_cdf) ≈ 1
    # Pin the last entry to exactly 1 so that searching for any value <= 1
    # always yields a valid bin index.
    bin_linidx_cdf[end] = 1

    # The mean is kept in the form returned by `_mean` because `_var` and
    # `_cov` take it as an argument; the stored copies are converted to `T`.
    mean_est_tpl = _mean(nh)
    mean_est = [T.(mean_est_tpl)...]
    mode_est = [T.(_mode(nh))...]
    var_est = [T.(_var(nh, mean_est_tpl))...]
    cov_est = T.(_cov(nh, mean_est_tpl))

    return MvBinnedDist(
        edges, bin_pdf, bin_linidx_cdf, closed_left,
        mean_est, mode_est, var_est, cov_est
    )
end
6173

74+
# Convenience constructor: the statistics element type is the floating-point
# counterpart of the promoted element types of the histogram's edge vectors.
function MvBinnedDist(h::Histogram{<:Real})
    edge_eltype = promote_type(map(eltype, h.edges)...)
    return MvBinnedDist{float(edge_eltype)}(h)
end
6275

63-
Base.convert(::Type{Histogram}, d::MvBinnedDist) = d.hist
76+
77+
# Rebuild `d` with every field passed through `adapt(to, ...)`; the edge tuple
# is adapted element by element.
function Adapt.adapt_structure(to, d::MvBinnedDist)
    conv(x) = adapt(to, x)
    return MvBinnedDist(
        map(conv, d._edges),
        conv(d._bin_pdf),
        conv(d._bin_linidx_cdf),
        conv(d._closed_left),
        conv(d._mean),
        conv(d._mode),
        conv(d._var),
        conv(d._cov)
    )
end
6483

6584

66-
Base.length(d::MvBinnedDist{T, N}) where {T, N} = N
67-
Base.size(d::MvBinnedDist{T, N}) where {T, N} = (N,)
68-
Base.eltype(d::MvBinnedDist{T, N}) where {T, N} = T
85+
# Convert a binned distribution back into a histogram.
#
# The method is attached to the qualified name `StatsBase.Histogram`, so it
# extends the StatsBase type even when `Histogram` is only visible through
# `using StatsBase` (a name brought in by `using` cannot be extended without
# qualification or an explicit import).
function StatsBase.Histogram(d::MvBinnedDist)
    closed = d._closed_left ? :left : :right
    # Edges and weights are copied into plain `Array`s; the final `true` is the
    # histogram's `isdensity` flag, since the stored weights are densities.
    return Histogram(map(Array, d._edges), Array(d._bin_pdf), closed, true)
end
86+
# `convert(Histogram, d)` simply forwards to the `Histogram(d)` constructor.
function Base.convert(::Type{Histogram}, d::MvBinnedDist)
    return Histogram(d)
end
6987

70-
Statistics.mean(d::MvBinnedDist{T, N}) where {T, N} = d._mean
71-
StatsBase.mode(d::MvBinnedDist{T, N}) where {T, N} = d._mode
72-
Statistics.var(d::MvBinnedDist{T, N}) where {T, N} = d._var
73-
Statistics.cov(d::MvBinnedDist{T, N}) where {T, N} = d._cov
7488

89+
# The dimensionality `N` of the distribution is the length of one variate.
function Base.length(d::MvBinnedDist{T,N}) where {T,N}
    return N
end

# A variate is a vector with `N` entries.
function Base.size(d::MvBinnedDist{T,N}) where {T,N}
    return (N,)
end

# The element type is the first type parameter `T`.
function Base.eltype(d::MvBinnedDist{T,N}) where {T,N}
    return T
end
7592

76-
function Distributions._rand!(r::AbstractRNG, d::MvBinnedDist{T,N}, A::AbstractVector{<:Real}) where {T, N}
77-
rand!(r, A)
78-
next_inds::UnitRange{Int} = searchsorted(d._probability_edges::Vector{T}, A[1]::T)
79-
cell_lin_index::Int = min(next_inds.start, next_inds.stop)
80-
cell_car_index = d._cart_inds[cell_lin_index]
81-
for idim in Base.OneTo(N)
82-
i = cell_car_index[idim]
83-
sub_int = d._edges[idim][i:i+1]
84-
sub_int_width::T = sub_int[2] - sub_int[1]
85-
A[idim] = sub_int[1] + sub_int_width * A[idim]
93+
# The summary statistics are computed once at construction time; these
# accessors only return the stored fields.
function Statistics.mean(d::MvBinnedDist)
    return d._mean
end

function StatsBase.mode(d::MvBinnedDist)
    return d._mode
end

function Statistics.var(d::MvBinnedDist)
    return d._var
end

function Statistics.cov(d::MvBinnedDist)
    return d._cov
end
97+
98+
99+
# Draw a single variate from `d` into the vector `A` (in place) and return `A`.
#
# A bin is chosen by looking up a uniform random number in the cumulative bin
# probabilities (`_bin_linidx_cdf`, ordered by linear bin index). Each
# coordinate is then drawn via `_rand_uniform` between the lower and upper
# edge of that bin along the corresponding axis.
function Distributions._rand!(
    rng::AbstractRNG, d::MvBinnedDist{T,N}, A::AbstractVector{<:Real}
) where {T,N}
    @assert length(eachindex(A)) == N
    u = rand(rng)
    bin = searchsortedfirst(d._bin_linidx_cdf, u)
    idx_lo = CartesianIndices(d._bin_pdf)[bin]
    # The upper edges sit one step further along every axis. The unit offset is
    # built from `N` so that this works for any dimensionality; a literal
    # `CartesianIndex(1, 1)` can only be added to a two-dimensional index.
    idx_hi = idx_lo + CartesianIndex(ntuple(_ -> 1, Val(N)))
    x_lo = map(getindex, d._edges, idx_lo.I)
    x_hi = map(getindex, d._edges, idx_hi.I)
    for dim in 1:N
        A[dim] = _rand_uniform(rng, T, x_lo[dim], x_hi[dim])
    end
    return A
end
89112

90-
function Distributions._rand!(r::AbstractRNG, d::MvBinnedDist{T,N}, A::AbstractMatrix{<:Real}) where {T, N}
91-
Distributions._rand!.((r,), (d,), nestedview(A))
113+
# Fill the matrix `A` in place by drawing one variate into each of the vectors
# produced by `nestedview(A)`, then return `A`.
function Distributions._rand!(
    rng::AbstractRNG, d::MvBinnedDist{T,N}, A::AbstractMatrix{<:Real}
) where {T,N}
    for variate in nestedview(A)
        Distributions._rand!(rng, d, variate)
    end
    return A
end
94117

@@ -104,17 +127,24 @@ end
104127
end
105128

106129

107-
function Distributions.pdf(d::MvBinnedDist{T,N}, x::AbstractVector{<:Real}) where {T,N}
130+
# Evaluate the density of `d` at the point `x`.
#
# The bin containing `x` is located with `_find_bin`; its stored value is
# returned converted to `U`. If the located index lies outside the bin array,
# the result is `zero(U)`. Throws an `ArgumentError` when the length of `x`
# differs from the dimensionality `N`.
function Distributions.pdf(
    d::MvBinnedDist{T,N,U}, x::AbstractVector{<:Real}
) where {T,N,U}
    if length(eachindex(x)) != N
        throw(ArgumentError("Length of variate doesn't match dimensionality of distribution"))
    end

    point = _unsafe_unroll_tuple(x, Val(N))
    bin = _find_bin(d._edges, d._closed_left, point)

    # Out-of-range bin index: no density there.
    checkbounds(Bool, d._bin_pdf, bin...) || return zero(U)

    # The bounds were verified just above, so the check can be skipped here.
    @inbounds density = d._bin_pdf[bin...]
    return convert(U, density)
end
112142

113143

114-
function Distributions.logpdf(d::MvBinnedDist{T, N}, x::AbstractArray{<:Real, 1}) where {T, N}
144+
# Log-density at `x`: the logarithm of `pdf(d, x)`.
Distributions.logpdf(d::MvBinnedDist{T,N}, x::AbstractArray{<:Real, 1}) where {T,N} =
    log(pdf(d, x))
117147

118-
function Distributions._logpdf(d::MvBinnedDist{T,N}, x::AbstractArray{<:Real, 1}) where {T, N}
148+
# Internal log-density hook; forwards to `logpdf(d, x)`.
Distributions._logpdf(d::MvBinnedDist{T,N}, x::AbstractArray{<:Real, 1}) where {T,N} =
    logpdf(d, x)

0 commit comments

Comments
 (0)