Skip to content

Commit 8bb4181

Browse files
bgctwdevmotion
andauthored
add partype method to lognormal and semicircle (#1773)
* add partype method to lognormal and semicircle * Update src/univariate/continuous/lognormal.jl Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Update src/univariate/continuous/semicircle.jl Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * add tests on par_type and Float32 * adapt partype testing to special case of nothing parameters and fix warning on ambiguous global variable * breaking adapt tests of special cases: Int partype * LocationScale: do not promote T but promote eltype T with eltype(inner) but also promote partype T with partype(inner) * Update .gitignore Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * revert DiscreteUniform to non-parametric will be moved to its own pull-request * generalize check of partype for non-Real parameters e.g. Nothing in Truncated distribution * Update src/univariate/locationscale.jl D is not used inside the function any more -> can simpify Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Remove comments * Apply suggestions from code review --------- Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
1 parent f33cc10 commit 8bb4181

File tree

8 files changed

+20
-7
lines changed

8 files changed

+20
-7
lines changed

src/univariate/continuous/lognormal.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ stdlogx(d::LogNormal) = d.σ
6262
mean(d::LogNormal) = ((μ, σ) = params(d); exp+ σ^2/2))
6363
median(d::LogNormal) = exp(d.μ)
6464
mode(d::LogNormal) = ((μ, σ) = params(d); exp- σ^2))
65+
partype(::LogNormal{T}) where {T<:Real} = T
6566

6667
function var(d::LogNormal)
6768
(μ, σ) = params(d)

src/univariate/continuous/semicircle.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Semicircle(r::Integer; check_args::Bool=true) = Semicircle(float(r); check_args=
3434
@distr_support Semicircle -d.r +d.r
3535

3636
params(d::Semicircle) = (d.r,)
37+
partype(::Semicircle{T}) where {T<:Real} = T
3738

3839
mean(d::Semicircle) = zero(d.r)
3940
var(d::Semicircle) = d.r^2 / 4

src/univariate/discrete/discreteuniform.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ span(d::DiscreteUniform) = d.b - d.a + 1
4242
probval(d::DiscreteUniform) = d.pv
4343
params(d::DiscreteUniform) = (d.a, d.b)
4444

45+
partype(::DiscreteUniform) = Int
46+
4547
### Show
4648

4749
show(io::IO, d::DiscreteUniform) = show(io, d, (:a, :b))

src/univariate/discrete/hypergeometric.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ end
3838

3939
@distr_support Hypergeometric max(d.n - d.nf, 0) min(d.ns, d.n)
4040

41+
partype(::Hypergeometric) = Int
4142

4243
### Parameters
4344

src/univariate/locationscale.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@ end
5151

5252
function AffineDistribution::T, σ::T, ρ::UnivariateDistribution; check_args::Bool=true) where {T<:Real}
5353
@check_args AffineDistribution (σ, !iszero(σ))
54-
_T = promote_type(eltype(ρ), T)
55-
return AffineDistribution{_T}(_T(μ), _T(σ), ρ)
54+
# μ and σ act on both random numbers and parameter-like quantities like mean
55+
# hence do not promote: but take care in eltype and partype
56+
return AffineDistribution{T}(μ, σ, ρ)
5657
end
5758

5859
function AffineDistribution::Real, σ::Real, ρ::UnivariateDistribution; check_args::Bool=true)
@@ -71,7 +72,7 @@ end
7172
const ContinuousAffineDistribution{T<:Real,D<:ContinuousUnivariateDistribution} = AffineDistribution{T,Continuous,D}
7273
const DiscreteAffineDistribution{T<:Real,D<:DiscreteUnivariateDistribution} = AffineDistribution{T,Discrete,D}
7374

74-
Base.eltype(::Type{<:AffineDistribution{T}}) where T = T
75+
Base.eltype(::Type{<:AffineDistribution{T,S,D}}) where {T,S,D} = promote_type(eltype(D), T)
7576

7677
minimum(d::AffineDistribution) =
7778
d.σ > 0 ? d.μ + d.σ * minimum(d.ρ) : d.μ + d.σ * maximum(d.ρ)
@@ -102,7 +103,7 @@ Base.convert(::Type{AffineDistribution{T}}, d::AffineDistribution{T}) where {T<:
102103
location(d::AffineDistribution) = d.μ
103104
scale(d::AffineDistribution) = d.σ
104105
params(d::AffineDistribution) = (d.μ,d.σ,d.ρ)
105-
partype(::AffineDistribution{T}) where {T} = T
106+
partype(d::AffineDistribution{T}) where {T} = promote_type(partype(d.ρ), T)
106107

107108
#### Statistics
108109

test/univariate/locationscale.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ function test_location_scale(
117117
rand!(rng, dtest, r)
118118
end
119119
@test mean(r) mean(dref) atol=0.02
120-
@test std(r) std(dref) atol=0.01
120+
@test std(r) std(dref) atol=0.02
121121
@test cf(dtest, -0.1) cf(dref,-0.1)
122122

123123
if dref isa ContinuousDistribution

test/univariates.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ function verify_and_test(D::Union{Type,Function}, d::UnivariateDistribution, dct
6262
# test various constructors for promotion, all-Integer args, etc.
6363
pars = params(d)
6464

65+
# verify parameter type
66+
# truncated parameters may be nothing
67+
@test partype(d) === mapfoldl(
68+
typeof, (S, T) -> T <: Distribution ? promote_type(S, partype(T)) : (T <: Nothing ? S : promote_type(S, eltype(T))), pars; init = Union{})
69+
6570
# promotion constructor:
6671
float_pars = map(x -> isa(x, AbstractFloat), pars)
6772
if length(pars) > 1 && sum(float_pars) > 1 && !isa(D, typeof(truncated))

test/utils.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@ r = RealInterval(1.5, 4.0)
1515

1616
# special cases
1717
@test partype(Kolmogorov()) == Float64
18-
@test partype(Hypergeometric(2, 2, 2)) == Float64
19-
@test partype(DiscreteUniform(0, 4)) == Float64
18+
@test partype(Hypergeometric(2, 2, 2)) == Int
19+
@test partype(Hypergeometric(2.0, 2, 2)) == Int
20+
@test partype(DiscreteUniform(0, 4)) == Int
21+
@test partype(DiscreteUniform(0.0, 4)) == Int
2022

2123
A = rand(1:10, 5, 5)
2224
B = rand(Float32, 4)

0 commit comments

Comments
 (0)