Skip to content

Commit f889f9e

Browse files
authored
Respect type of parameters in fit for Bernoulli, Binomial, and Uniform (#1558)
* Respect parameter type in `fit` for `Bernoulli` and `Binomial` * Respect parameter type in `fit` for `Uniform`
1 parent a350622 commit f889f9e

File tree

5 files changed

+38
-32
lines changed

5 files changed

+38
-32
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Distributions"
22
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
33
authors = ["JuliaStats"]
4-
version = "0.25.61"
4+
version = "0.25.62"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/univariate/continuous/uniform.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ External links
2626
struct Uniform{T<:Real} <: ContinuousUnivariateDistribution
2727
a::T
2828
b::T
29-
Uniform{T}(a::T, b::T) where {T <: Real} = new{T}(a, b)
29+
Uniform{T}(a::Real, b::Real) where {T <: Real} = new{T}(a, b)
3030
end
3131

3232
function Uniform(a::T, b::T; check_args::Bool=true) where {T <: Real}
@@ -125,11 +125,11 @@ _rand!(rng::AbstractRNG, d::Uniform, A::AbstractArray{<:Real}) =
125125

126126
#### Fitting
127127

128-
function fit_mle(::Type{<:Uniform}, x::AbstractArray{<:Real})
128+
function fit_mle(::Type{T}, x::AbstractArray{<:Real}) where {T<:Uniform}
129129
if isempty(x)
130130
throw(ArgumentError("x cannot be empty."))
131131
end
132-
return Uniform(extrema(x)...)
132+
return T(extrema(x)...)
133133
end
134134

135135
# ChainRules definitions

src/univariate/discrete/bernoulli.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ External links:
2727
struct Bernoulli{T<:Real} <: DiscreteUnivariateDistribution
2828
p::T
2929

30-
Bernoulli{T}(p::T) where {T <: Real} = new{T}(p)
30+
Bernoulli{T}(p::Real) where {T <: Real} = new{T}(p)
3131
end
3232

3333
function Bernoulli(p::Real; check_args::Bool=true)
@@ -120,7 +120,7 @@ end
120120

121121
BernoulliStats(c0::Real, c1::Real) = BernoulliStats(promote(c0, c1)...)
122122

123-
fit_mle(::Type{<:Bernoulli}, ss::BernoulliStats) = Bernoulli(ss.cnt1 / (ss.cnt0 + ss.cnt1))
123+
fit_mle(::Type{T}, ss::BernoulliStats) where {T<:Bernoulli} = T(ss.cnt1 / (ss.cnt0 + ss.cnt1))
124124

125125
function suffstats(::Type{<:Bernoulli}, x::AbstractArray{<:Integer})
126126
c0 = c1 = 0

src/univariate/discrete/binomial.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,15 @@ end
187187

188188
const BinomData = Tuple{Int, AbstractArray}
189189

190-
suffstats(::Type{<:Binomial}, data::BinomData) = suffstats(Binomial, data...)
191-
suffstats(::Type{<:Binomial}, data::BinomData, w::AbstractArray{<:Real}) = suffstats(Binomial, data..., w)
190+
suffstats(::Type{T}, data::BinomData) where {T<:Binomial} = suffstats(T, data...)
191+
suffstats(::Type{T}, data::BinomData, w::AbstractArray{<:Real}) where {T<:Binomial} = suffstats(T, data..., w)
192192

193-
fit_mle(::Type{<:Binomial}, ss::BinomialStats) = Binomial(ss.n, ss.ns / (ss.ne * ss.n))
193+
fit_mle(::Type{T}, ss::BinomialStats) where {T<:Binomial} = T(ss.n, ss.ns / (ss.ne * ss.n))
194194

195-
fit_mle(::Type{<:Binomial}, n::Integer, x::AbstractArray{<:Integer}) = fit_mle(Binomial, suffstats(Binomial, n, x))
196-
fit_mle(::Type{<:Binomial}, n::Integer, x::AbstractArray{<:Integer}, w::AbstractArray{<:Real}) = fit_mle(Binomial, suffstats(Binomial, n, x, w))
197-
fit_mle(::Type{<:Binomial}, data::BinomData) = fit_mle(Binomial, suffstats(Binomial, data))
198-
fit_mle(::Type{<:Binomial}, data::BinomData, w::AbstractArray{<:Real}) = fit_mle(Binomial, suffstats(Binomial, data, w))
195+
fit_mle(::Type{T}, n::Integer, x::AbstractArray{<:Integer}) where {T<:Binomial}= fit_mle(T, suffstats(T, n, x))
196+
fit_mle(::Type{T}, n::Integer, x::AbstractArray{<:Integer}, w::AbstractArray{<:Real}) where {T<:Binomial} = fit_mle(T, suffstats(T, n, x, w))
197+
fit_mle(::Type{T}, data::BinomData) where {T<:Binomial} = fit_mle(T, suffstats(T, data))
198+
fit_mle(::Type{T}, data::BinomData, w::AbstractArray{<:Real}) where {T<:Binomial} = fit_mle(T, suffstats(T, data, w))
199199

200-
fit(::Type{<:Binomial}, data::BinomData) = fit_mle(Binomial, data)
201-
fit(::Type{<:Binomial}, data::BinomData, w::AbstractArray{<:Real}) = fit_mle(Binomial, data, w)
200+
fit(::Type{T}, data::BinomData) where {T<:Binomial} = fit_mle(T, data)
201+
fit(::Type{T}, data::BinomData, w::AbstractArray{<:Real}) where {T<:Binomial} = fit_mle(T, data, w)

test/fit.jl

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ end
3434

3535

3636
@testset "Testing fit for Bernoulli" begin
37-
for rng in ((), (rng,)), D in (Bernoulli, Bernoulli{Float64})
37+
for rng in ((), (rng,)), D in (Bernoulli, Bernoulli{Float64}, Bernoulli{Float32})
3838
v = rand(rng..., n0)
39-
z = rand(rng..., D(0.7), n0)
39+
z = rand(rng..., Bernoulli(0.7), n0)
4040
for x in (z, OffsetArray(z, -n0 ÷ 2)), w in (v, OffsetArray(v, -n0 ÷ 2))
4141
ss = @inferred suffstats(D, x)
4242
@test ss isa Distributions.BernoulliStats
@@ -57,7 +57,7 @@ end
5757
@test mean(d) sum(v[z .== 1]) / sum(v)
5858
end
5959

60-
z = rand(rng..., D(0.7), N)
60+
z = rand(rng..., Bernoulli(0.7), N)
6161
for x in (z, OffsetArray(z, -N ÷ 2))
6262
d = @inferred fit(D, x)
6363
@test d isa D
@@ -82,9 +82,9 @@ end
8282
end
8383

8484
@testset "Testing fit for Binomial" begin
85-
for rng in ((), (rng,)), D in (Binomial, Binomial{Float64})
85+
for rng in ((), (rng,)), D in (Binomial, Binomial{Float64}, Binomial{Float32})
8686
v = rand(rng..., n0)
87-
z = rand(rng..., D(100, 0.3), n0)
87+
z = rand(rng..., Binomial(100, 0.3), n0)
8888
for x in (z, OffsetArray(z, -n0 ÷ 2)), w in (v, OffsetArray(v, -n0 ÷ 2))
8989
ss = @inferred suffstats(D, (100, x))
9090
@test ss isa Distributions.BinomialStats
@@ -109,7 +109,7 @@ end
109109
@test succprob(d) dot(z, v) / (sum(v) * 100)
110110
end
111111

112-
z = rand(rng..., D(100, 0.3), N)
112+
z = rand(rng..., Binomial(100, 0.3), N)
113113
for x in (z, OffsetArray(z, -N ÷ 2))
114114
d = @inferred fit(D, 100, x)
115115
@test d isa D
@@ -291,18 +291,24 @@ end
291291
end
292292

293293
@testset "Testing fit for Uniform" begin
294-
for func in funcs, dist in (Uniform, Uniform{Float64})
295-
x = func[2](dist(1.2, 5.8), n0)
296-
d = fit(dist, x)
297-
@test isa(d, dist)
298-
@test 1.2 <= minimum(d) <= maximum(d) <= 5.8
299-
@test minimum(d) == minimum(x)
300-
@test maximum(d) == maximum(x)
294+
for rng in ((), (rng,)), D in (Uniform, Uniform{Float64}, Uniform{Float32})
295+
z = rand(rng..., Uniform(1.2, 5.8), n0)
296+
for x in (z, OffsetArray(z, -n0 ÷ 2))
297+
d = fit(D, x)
298+
@test d isa D
299+
@test 1.2 <= minimum(d) <= maximum(d) <= 5.8
300+
@test minimum(d) == partype(d)(minimum(z))
301+
@test maximum(d) == partype(d)(maximum(z))
302+
end
301303

302-
d = fit(dist, func[2](dist(1.2, 5.8), N))
303-
@test 1.2 <= minimum(d) <= maximum(d) <= 5.8
304-
@test isapprox(minimum(d), 1.2, atol=0.02)
305-
@test isapprox(maximum(d), 5.8, atol=0.02)
304+
z = rand(rng..., Uniform(1.2, 5.8), N)
305+
for x in (z, OffsetArray(z, -N ÷ 2))
306+
d = fit(D, x)
307+
@test d isa D
308+
@test 1.2 <= minimum(d) <= maximum(d) <= 5.8
309+
@test minimum(d) 1.2 atol=0.02
310+
@test maximum(d) 5.8 atol=0.02
311+
end
306312
end
307313
end
308314

0 commit comments

Comments
 (0)