Skip to content

Commit 51c5cd6

Browse files
authored
Merge pull request #377 from devmotion/dw/refactor_gamma_inc_inv
Refactor `gamma_inc_inv`
2 parents d829f0a + b8f3057 commit 51c5cd6

File tree

2 files changed

+41
-54
lines changed

2 files changed

+41
-54
lines changed

src/gamma_inc.jl

Lines changed: 34 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ function gamma_inc_fsum(a::Float64, x::Float64)
665665
end
666666

667667
"""
668-
gamma_inc_inv_psmall(a,p)
668+
gamma_inc_inv_psmall(a, logr)
669669
670670
Compute `x0` - initial approximation when `p` is small.
671671
Here we invert the series in Eqn (2.20) in the paper and write the inversion problem as:
@@ -675,8 +675,7 @@ x = r\\left[1 + a\\sum_{k=1}^{\\infty}\\frac{(-x)^{n}}{(a+n)n!}\\right]^{-1/a},
675675
where ``r = (p\\Gamma(1+a))^{1/a}``
676676
Inverting this relation we obtain ``x = r + \\sum_{k=2}^{\\infty}c_{k}r^{k}``.
677677
"""
678-
function gamma_inc_inv_psmall(a::Float64, p::Float64)
679-
logr = (1.0/a)*(log(p) + logabsgamma(a + 1.0)[1])
678+
function gamma_inc_inv_psmall(a::Float64, logr::Float64)
680679
r = exp(logr)
681680
ap1 = a + 1.0
682681
ap1² = ap1*ap1
@@ -726,19 +725,7 @@ function gamma_inc_inv_qsmall(a::Float64, q::Float64)
726725
end
727726

728727
"""
729-
gamma_inc_inv_asmall(a,p,q,pcase)
730-
731-
Compute `x0` - initial approximation when `a` is small.
732-
Here the solution `x` of ``P(a,x)=p`` satisfies ``x_{l} < x < x_{u}``
733-
where ``x_{l} = (p\\Gamma(a+1))^{1/a}`` and ``x_{u} = -\\log{(1 - p\\Gamma(a+1))}``, and is used as starting value for Newton iteration.
734-
"""
735-
function gamma_inc_inv_asmall(a::Float64, p::Float64, q::Float64, pcase::Bool)
736-
logp = (pcase) ? log(p) : log1p(-q)
737-
return exp((1.0/a)*(logp +loggamma1p(a)))
738-
end
739-
740-
"""
741-
gamma_inc_inv_alarge(a,porq,s)
728+
gamma_inc_inv_alarge(a, minpq, pcase)
742729
743730
Compute `x0` - initial approximation when `a` is large.
744731
The inversion problem is rewritten as :
@@ -753,9 +740,10 @@ and it is possible to expand:
753740
which is calculated by coeff1, coeff2 and coeff3 functions below.
754741
This returns a tuple `(x0,fp)`, where `fp` is computed since it's an approximation for the coefficient after inverting the original power series.
755742
"""
756-
function gamma_inc_inv_alarge(a::Float64, porq::Float64, s::Integer)
757-
r = erfcinv(2*porq)
758-
eta = s*r/sqrt(a*0.5)
743+
function gamma_inc_inv_alarge(a::Float64, minpq::Float64, pcase::Bool)
744+
r = erfcinv(2*minpq)
745+
s = r/sqrt(a*0.5)
746+
eta = pcase ? -s : s
759747
eta += (coeff1(eta) + (coeff2(eta) + coeff3(eta)/a)/a)/a
760748
x0 = a*lambdaeta(eta)
761749
fp = -sqrt(inv2π*a)*exp(-0.5*a*eta*eta)/gammax(a)
@@ -919,45 +907,47 @@ External links: [DLMF](https://dlmf.nist.gov/8.2.4), [Wikipedia](https://en.wiki
919907
920908
See also: [`gamma_inc(a,x,ind)`](@ref SpecialFunctions.gamma_inc).
921909
"""
922-
gamma_inc_inv(a::Real, p::Real, q::Real) = _gamma_inc_inv(promote(float(a), float(p), float(q))...)
923-
924-
function _gamma_inc_inv(a::Float64, p::Float64, q::Float64)
910+
function gamma_inc_inv(a::Real, p::Real, q::Real)
911+
return _gamma_inc_inv(map(float, promote(a, p, q))...)
912+
end
925913

914+
# `gamma inc_inv` ensures that arguments of `_gamma_inc_inv` are
915+
# floating point numbers of the same type
916+
function _gamma_inc_inv(a::T, p::T, q::T) where {T<:Real}
926917
if p + q != 1
927918
throw(ArgumentError("p + q must equal one but is $(p + q)"))
928919
end
929920

930921
if iszero(p)
931-
return 0.0
922+
return zero(T)
932923
elseif iszero(q)
933-
return Inf
924+
return T(Inf)
934925
end
935926

936-
if p < 0.5
937-
pcase = true
938-
porq = p
939-
s = -1
940-
else
941-
pcase = false
942-
porq = q
943-
s = 1
944-
end
927+
pcase = p < 0.5
928+
minpq = pcase ? p : q
929+
return __gamma_inc_inv(a, minpq, pcase)
930+
end
931+
932+
function __gamma_inc_inv(a::Float64, minpq::Float64, pcase::Bool)
945933
haseta = false
946934

947-
logr = (1.0/a)*(log(p) + logabsgamma(a + 1.0)[1])
935+
logp = pcase ? log(minpq) : log1p(-minpq)
936+
loggamma1pa = a <= 1.0 ? loggamma1p(a) : loggamma(a + 1.0)
937+
logr = (logp + loggamma1pa) / a
948938
if logr < log(0.2*(1 + a)) #small value of p
949-
x0 = gamma_inc_inv_psmall(a, p)
950-
elseif ((q < min(0.02, exp(-1.5*a)/gamma(a))) && (a < 10)) #small q
951-
x0 = gamma_inc_inv_qsmall(a, q)
952-
elseif abs(porq - 0.5) < 1.0e-05
939+
x0 = gamma_inc_inv_psmall(a, logr)
940+
elseif !pcase && minpq < min(0.02, exp(-1.5*a)/gamma(a)) && a < 10 #small q
941+
x0 = gamma_inc_inv_qsmall(a, minpq)
942+
elseif abs(minpq - 0.5) < 1.0e-05
953943
x0 = a - 1.0/3.0 + (8.0/405.0 + 184.0/25515.0/a)/a
954944
elseif abs(a - 1.0) < 1.0e-4
955-
x0 = pcase ? -log1p(-p) : -log(q)
945+
x0 = pcase ? -log1p(-minpq) : -log(minpq)
956946
elseif a < 1.0 # small value of a
957-
x0 = gamma_inc_inv_asmall(a, p, q, pcase)
947+
x0 = exp(logr)
958948
else #large a
959949
haseta = true
960-
x0, fp = gamma_inc_inv_alarge(a, porq, s)
950+
x0, fp = gamma_inc_inv_alarge(a, minpq, pcase)
961951
end
962952

963953
t = 1
@@ -981,7 +971,7 @@ function _gamma_inc_inv(a::Float64, p::Float64, q::Float64)
981971

982972
px, qx = gamma_inc(a, x, 0)
983973

984-
ck1 = pcase ? -r*(px - p) : r*(qx - q)
974+
ck1 = pcase ? -r*(px - minpq) : r*(qx - minpq)
985975
if a > 0.05
986976
ck2 = (x - a + 1.0)/(2.0*x)
987977

@@ -1014,16 +1004,8 @@ function _gamma_inc_inv(a::Float64, p::Float64, q::Float64)
10141004
return x
10151005
end
10161006

1017-
function _gamma_inc_inv(a::T, p::T, q::T) where {T <: Union{Float16, Float32}}
1018-
if p + q != one(T)
1019-
throw(ArgumentError("p + q must equal one but was $(p + q)"))
1020-
end
1021-
p64, q64 = if p < q
1022-
(Float64(p), 1 - Float64(p))
1023-
else
1024-
(1 - Float64(q), Float64(q))
1025-
end
1026-
return T(_gamma_inc_inv(Float64(a), p64, q64))
1007+
function __gamma_inc_inv(a::T, minpq::T, pcase::Bool) where {T<:Union{Float16,Float32}}
1008+
return T(__gamma_inc_inv(Float64(a), Float64(minpq), pcase))
10271009
end
10281010

10291011
# like promote(x,y), but don't complexify real values

test/gamma_inc.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,13 @@ end
170170
@testset "Low precision with Float64(p) + Float64(q) != 1" for T in (Float16, Float32)
171171
@test gamma_inc(T(1.0), gamma_inc_inv(T(1.0), T(0.1), T(0.9)))[1]::T T(0.1)
172172
@test gamma_inc(T(1.0), gamma_inc_inv(T(1.0), T(0.9), T(0.1)))[2]::T T(0.1)
173-
@test_throws ArgumentError("p + q must equal one but was 1.02") gamma_inc_inv(T(1.0), T(0.1), T(0.92))
174-
@test_throws ArgumentError("p + q must equal one but was 1.02") gamma_inc_inv(T(1.0), T(0.92), T(0.1))
173+
@test_throws ArgumentError("p + q must equal one but is 1.02") gamma_inc_inv(T(1.0), T(0.1), T(0.92))
174+
@test_throws ArgumentError("p + q must equal one but is 1.02") gamma_inc_inv(T(1.0), T(0.92), T(0.1))
175+
end
176+
177+
@testset "Promotion of arguments" begin
178+
@test @inferred(gamma_inc_inv(1//2, 0.3f0, 0.7f0)) isa Float32
179+
@test @inferred(gamma_inc_inv(1, 0.2f0, 0.8f0)) isa Float32
175180
end
176181
end
177182

0 commit comments

Comments
 (0)