Skip to content

Commit a1df96d

Browse files
authored
Avoid StackOverflowError with erf* functions (#353)
* Avoid StackOverflowError with erf* functions * Rearrange BigFloat implementations
1 parent c0442c2 commit a1df96d

File tree

2 files changed

+107
-92
lines changed

2 files changed

+107
-92
lines changed

src/erf.jl

Lines changed: 89 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -4,40 +4,73 @@ using Base.Math: @horner
44
using Base.MPFR: ROUNDING_MODE
55

66
for f in (:erf, :erfc)
7+
internalf = Symbol(:_, f)
8+
libopenlibmf = QuoteNode(f)
9+
libopenlibmf0 = QuoteNode(Symbol(f, :f))
10+
openspecfunf = QuoteNode(Symbol(:Faddeeva_, f))
11+
mpfrf = QuoteNode(Symbol(:mpfr_, f))
712
@eval begin
8-
($f)(x::Float64) = ccall(($(string(f)),libopenlibm), Float64, (Float64,), x)
9-
($f)(x::Float32) = ccall(($(string(f,"f")),libopenlibm), Float32, (Float32,), x)
10-
($f)(x::Real) = ($f)(float(x))
11-
($f)(a::Float16) = Float16($f(Float32(a)))
12-
($f)(a::Complex{Float16}) = Complex{Float16}($f(Complex{Float32}(a)))
13-
function ($f)(x::BigFloat)
13+
$f(x::Number) = $internalf(float(x))
14+
15+
$internalf(x::Float64) = ccall(($libopenlibmf, libopenlibm), Float64, (Float64,), x)
16+
$internalf(x::Float32) = ccall(($libopenlibmf0, libopenlibm), Float32, (Float32,), x)
17+
$internalf(x::Float16) = Float16($internalf(Float32(x)))
18+
19+
$internalf(z::Complex{Float64}) = Complex{Float64}(ccall(($openspecfunf, libopenspecfun), Complex{Float64}, (Complex{Float64}, Float64), z, zero(Float64)))
20+
$internalf(z::Complex{Float32}) = Complex{Float32}(ccall(($openspecfunf, libopenspecfun), Complex{Float64}, (Complex{Float64}, Float64), Complex{Float64}(z), Float64(eps(Float32))))
21+
$internalf(z::Complex{Float16}) = Complex{Float16}($internalf(Complex{Float32}(z)))
22+
23+
function $internalf(x::BigFloat)
1424
z = BigFloat()
15-
ccall(($(string(:mpfr_,f)), :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}, Int32), z, x, ROUNDING_MODE[])
25+
ccall(($mpfrf, :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}, Int32), z, x, ROUNDING_MODE[])
1626
return z
1727
end
18-
($f)(x::AbstractFloat) = error("not implemented for ", typeof(x))
1928
end
2029
end
2130

22-
for f in (:erf, :erfc, :erfcx, :erfi, :Dawson)
23-
fname = (f === :Dawson) ? :dawson : f
31+
for f in (:erfcx, :erfi, :dawson)
32+
internalf = Symbol(:_, f)
33+
openspecfunfsym = Symbol(:Faddeeva_, f === :dawson ? :Dawson : f)
34+
openspecfunfF64 = QuoteNode(Symbol(openspecfunfsym, :_re))
35+
openspecfunfCF64 = QuoteNode(openspecfunfsym)
2436
@eval begin
25-
($fname)(z::Complex{Float64}) = Complex{Float64}(ccall(($(string("Faddeeva_",f)),libopenspecfun), Complex{Float64}, (Complex{Float64}, Float64), z, zero(Float64)))
26-
($fname)(z::Complex{Float32}) = Complex{Float32}(ccall(($(string("Faddeeva_",f)),libopenspecfun), Complex{Float64}, (Complex{Float64}, Float64), Complex{Float64}(z), Float64(eps(Float32))))
37+
$f(x::Number) = $internalf(float(x))
38+
39+
$internalf(x::Float64) = ccall(($openspecfunfF64, libopenspecfun), Float64, (Float64,), x)
40+
$internalf(x::Float32) = Float32($internalf(Float64(x)))
41+
$internalf(x::Float16) = Float16($internalf(Float64(x)))
2742

28-
($fname)(z::Complex) = ($fname)(float(z))
29-
($fname)(z::Complex{<:AbstractFloat}) = throw(MethodError($fname,(z,)))
43+
$internalf(z::Complex{Float64}) = Complex{Float64}(ccall(($openspecfunfCF64, libopenspecfun), Complex{Float64}, (Complex{Float64}, Float64), z, zero(Float64)))
44+
$internalf(z::Complex{Float32}) = Complex{Float32}(ccall(($openspecfunfCF64, libopenspecfun), Complex{Float64}, (Complex{Float64}, Float64), Complex{Float64}(z), Float64(eps(Float32))))
45+
$internalf(z::Complex{Float16}) = Complex{Float16}($internalf(Complex{Float32}(z)))
3046
end
3147
end
3248

33-
for f in (:erfcx, :erfi, :Dawson)
34-
fname = (f === :Dawson) ? :dawson : f
35-
@eval begin
36-
($fname)(x::Float64) = ccall(($(string("Faddeeva_",f,"_re")),libopenspecfun), Float64, (Float64,), x)
37-
($fname)(x::Float32) = Float32(ccall(($(string("Faddeeva_",f,"_re")),libopenspecfun), Float64, (Float64,), Float64(x)))
38-
39-
($fname)(x::Real) = ($fname)(float(x))
40-
($fname)(x::AbstractFloat) = throw(MethodError($fname,(x,)))
49+
# MPFR has an open TODO item for this function
50+
# until then, we use [DLMF 7.12.1](https://dlmf.nist.gov/7.12.1) for the tail
51+
function _erfcx(x::BigFloat)
52+
if x <= (Clong == Int32 ? 0x1p15 : 0x1p30)
53+
# any larger gives internal overflow
54+
return exp(x^2)*erfc(x)
55+
elseif !isfinite(x)
56+
return 1/x
57+
else
58+
# asymptotic series
59+
# starts to diverge at iteration i = 2^30 or 2^60
60+
# final term will be < Γ(2*i+1)/(2^i * Γ(i+1)) / (2^(i+1))
61+
# so good to (lgamma(2*i+1) - lgamma(i+1))/log(2) - 2*i - 1
62+
# ≈ 3.07e10 or 6.75e19 bits
63+
# which is larger than the memory of the respective machines
64+
ϵ = eps(BigFloat)/4
65+
v = 1/(2*x*x)
66+
k = 1
67+
s = w = -k*v
68+
while abs(w) > ϵ
69+
k += 2
70+
w *= -k*v
71+
s += w
72+
end
73+
return (1+s)/(x*sqrtπ)
4174
end
4275
end
4376

@@ -204,7 +237,9 @@ Using the rational approximants tabulated in:
204237
> <http://www.jstor.org/stable/2005402>
205238
combined with Newton iterations for `BigFloat`.
206239
"""
207-
function erfinv(x::Float64)
240+
erfinv(x::Real) = _erfinv(float(x))
241+
242+
function _erfinv(x::Float64)
208243
a = abs(x)
209244
if a >= 1.0
210245
if x == 1.0
@@ -272,7 +307,7 @@ function erfinv(x::Float64)
272307
end
273308
end
274309

275-
function erfinv(x::Float32)
310+
function _erfinv(x::Float32)
276311
a = abs(x)
277312
if a >= 1.0f0
278313
if x == 1.0f0
@@ -315,7 +350,25 @@ function erfinv(x::Float32)
315350
end
316351
end
317352

318-
erfinv(x::Union{Integer,Rational}) = erfinv(float(x))
353+
function _erfinv(y::BigFloat)
354+
xfloat = erfinv(Float64(y))
355+
if isfinite(xfloat)
356+
x = BigFloat(xfloat)
357+
else
358+
# Float64 overflowed, use asymptotic estimate instead
359+
# from erfc(x) ≈ exp(-x²)/x√π ≈ y ⟹ -log(yπ) ≈ x² + log(x) ≈ x²
360+
x = copysign(sqrt(-log((1-abs(y))*sqrtπ)), y)
361+
isfinite(x) || return x
362+
end
363+
sqrtπhalf = sqrtπ * big(0.5)
364+
tol = 2eps(abs(x))
365+
while true # Newton iterations
366+
Δx = sqrtπhalf * (erf(x) - y) * exp(x^2)
367+
x -= Δx
368+
abs(Δx) < tol && break
369+
end
370+
return x
371+
end
319372

320373
@doc raw"""
321374
erfcinv(x)
@@ -341,7 +394,9 @@ Using the rational approximants tabulated in:
341394
> <http://www.jstor.org/stable/2005402>
342395
combined with Newton iterations for `BigFloat`.
343396
"""
344-
function erfcinv(y::Float64)
397+
erfcinv(x::Real) = _erfcinv(float(x))
398+
399+
function _erfcinv(y::Float64)
345400
if y > 0.0625
346401
return erfinv(1.0 - y)
347402
elseif y <= 0.0
@@ -393,7 +448,7 @@ function erfcinv(y::Float64)
393448
end
394449
end
395450

396-
function erfcinv(y::Float32)
451+
function _erfcinv(y::Float32)
397452
if y > 0.0625f0
398453
return erfinv(1.0f0 - y)
399454
elseif y <= 0.0f0
@@ -415,27 +470,7 @@ function erfcinv(y::Float32)
415470
end
416471
end
417472

418-
function erfinv(y::BigFloat)
419-
xfloat = erfinv(Float64(y))
420-
if isfinite(xfloat)
421-
x = BigFloat(xfloat)
422-
else
423-
# Float64 overflowed, use asymptotic estimate instead
424-
# from erfc(x) ≈ exp(-x²)/x√π ≈ y ⟹ -log(yπ) ≈ x² + log(x) ≈ x²
425-
x = copysign(sqrt(-log((1-abs(y))*sqrtπ)), y)
426-
isfinite(x) || return x
427-
end
428-
sqrtπhalf = sqrtπ * big(0.5)
429-
tol = 2eps(abs(x))
430-
while true # Newton iterations
431-
Δx = sqrtπhalf * (erf(x) - y) * exp(x^2)
432-
x -= Δx
433-
abs(Δx) < tol && break
434-
end
435-
return x
436-
end
437-
438-
function erfcinv(y::BigFloat)
473+
function _erfcinv(y::BigFloat)
439474
yfloat = Float64(y)
440475
xfloat = erfcinv(yfloat)
441476
if isfinite(xfloat)
@@ -461,36 +496,6 @@ function erfcinv(y::BigFloat)
461496
return x
462497
end
463498

464-
erfcinv(x::Union{Integer,Rational}) = erfcinv(float(x))
465-
466-
# MPFR has an open TODO item for this function
467-
# until then, we use [DLMF 7.12.1](https://dlmf.nist.gov/7.12.1) for the tail
468-
function erfcx(x::BigFloat)
469-
if x <= (Clong == Int32 ? 0x1p15 : 0x1p30)
470-
# any larger gives internal overflow
471-
return exp(x^2)*erfc(x)
472-
elseif !isfinite(x)
473-
return 1/x
474-
else
475-
# asymptotic series
476-
# starts to diverge at iteration i = 2^30 or 2^60
477-
# final term will be < Γ(2*i+1)/(2^i * Γ(i+1)) / (2^(i+1))
478-
# so good to (lgamma(2*i+1) - lgamma(i+1))/log(2) - 2*i - 1
479-
# ≈ 3.07e10 or 6.75e19 bits
480-
# which is larger than the memory of the respective machines
481-
ϵ = eps(BigFloat)/4
482-
v = 1/(2*x*x)
483-
k = 1
484-
s = w = -k*v
485-
while abs(w) > ϵ
486-
k += 2
487-
w *= -k*v
488-
s += w
489-
end
490-
return (1+s)/(x*sqrtπ)
491-
end
492-
end
493-
494499
@doc raw"""
495500
logerfc(x)
496501
@@ -511,7 +516,9 @@ See also: [`erfcx(x)`](@ref erfcx).
511516
Based on the [`erfc(x)`](@ref erfc) and [`erfcx(x)`](@ref erfcx) functions.
512517
Currently only implemented for `Float32`, `Float64`, and `BigFloat`.
513518
"""
514-
function logerfc(x::Union{Float32, Float64, BigFloat})
519+
logerfc(x::Real) = _logerfc(float(x))
520+
521+
function _logerfc(x::Union{Float32, Float64, BigFloat})
515522
# Don't include Float16 in the Union, otherwise logerfc would currently work for x <= 0.0, but not x > 0.0
516523
if x > 0.0
517524
return log(erfcx(x)) - x^2
@@ -520,9 +527,6 @@ function logerfc(x::Union{Float32, Float64, BigFloat})
520527
end
521528
end
522529

523-
logerfc(x::Real) = logerfc(float(x))
524-
logerfc(x::AbstractFloat) = throw(MethodError(logerfc, x))
525-
526530
@doc raw"""
527531
logerfcx(x)
528532
@@ -543,7 +547,9 @@ See also: [`erfcx(x)`](@ref erfcx).
543547
Based on the [`erfc(x)`](@ref erfc) and [`erfcx(x)`](@ref erfcx) functions.
544548
Currently only implemented for `Float32`, `Float64`, and `BigFloat`.
545549
"""
546-
function logerfcx(x::Union{Float32, Float64, BigFloat})
550+
logerfcx(x::Real) = _logerfcx(float(x))
551+
552+
function _logerfcx(x::Union{Float32, Float64, BigFloat})
547553
# Don't include Float16 in the Union, otherwise logerfc would currently work for x <= 0.0, but not x > 0.0
548554
if x < 0.0
549555
return log(erfc(x)) + x^2
@@ -552,9 +558,6 @@ function logerfcx(x::Union{Float32, Float64, BigFloat})
552558
end
553559
end
554560

555-
logerfcx(x::Real) = logerfcx(float(x))
556-
logerfcx(x::AbstractFloat) = throw(MethodError(logerfcx, x))
557-
558561
@doc raw"""
559562
logerf(x, y)
560563

test/erf.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
@test erfc(Float32(1)) 0.15729920705028513066 rtol=2*eps(Float32)
99
@test erfc(Float64(1)) 0.15729920705028513066 rtol=2*eps(Float64)
1010

11-
@test_throws MethodError erfcx(Float16(1))
11+
@test erfcx(Float16(1)) 0.42758357615580700442 rtol=2*eps(Float16)
1212
@test erfcx(Float32(1)) 0.42758357615580700442 rtol=2*eps(Float32)
1313
@test erfcx(Float64(1)) 0.42758357615580700442 rtol=2*eps(Float64)
1414

@@ -38,7 +38,7 @@
3838
@test logerfcx(Float32(1000)) -7.48012072190621214066734919080 rtol=2eps(Float32)
3939
@test logerfcx(Float64(1000)) -7.48012072190621214066734919080 rtol=2eps(Float64)
4040

41-
@test_throws MethodError erfi(Float16(1))
41+
@test erfi(Float16(1)) 1.6504257587975428760 rtol=2*eps(Float16)
4242
@test erfi(Float32(1)) 1.6504257587975428760 rtol=2*eps(Float32)
4343
@test erfi(Float64(1)) 1.6504257587975428760 rtol=2*eps(Float64)
4444

@@ -52,7 +52,7 @@
5252
@test erfcinv(Float32(0.15729920705028513066)) 1 rtol=2*eps(Float32)
5353
@test erfcinv(Float64(0.15729920705028513066)) 1 rtol=2*eps(Float64)
5454

55-
@test_throws MethodError dawson(Float16(1))
55+
@test dawson(Float16(1)) 0.53807950691276841914 rtol=2*eps(Float16)
5656
@test dawson(Float32(1)) 0.53807950691276841914 rtol=2*eps(Float32)
5757
@test dawson(Float64(1)) 0.53807950691276841914 rtol=2*eps(Float64)
5858
end
@@ -66,19 +66,19 @@
6666
@test erfc(ComplexF32(1+2im)) 1.5366435657785650340+5.0491437034470346695im
6767
@test erfc(ComplexF64(1+2im)) 1.5366435657785650340+5.0491437034470346695im
6868

69-
@test_throws MethodError erfcx(ComplexF16(1))
69+
@test erfcx(ComplexF16(1+2im)) 0.14023958136627794370-0.22221344017989910261im
7070
@test erfcx(ComplexF32(1+2im)) 0.14023958136627794370-0.22221344017989910261im
7171
@test erfcx(ComplexF64(1+2im)) 0.14023958136627794370-0.22221344017989910261im
7272

73-
@test_throws MethodError erfi(ComplexF16(1))
73+
@test erfi(ComplexF16(1+2im)) -0.011259006028815025076+1.0036063427256517509im
7474
@test erfi(ComplexF32(1+2im)) -0.011259006028815025076+1.0036063427256517509im
7575
@test erfi(ComplexF64(1+2im)) -0.011259006028815025076+1.0036063427256517509im
7676

7777
@test_throws MethodError erfinv(Complex(1))
7878

7979
@test_throws MethodError erfcinv(Complex(1))
8080

81-
@test_throws MethodError dawson(ComplexF16(1))
81+
@test dawson(ComplexF16(1+2im)) -13.388927316482919244-11.828715103889593303im
8282
@test dawson(ComplexF32(1+2im)) -13.388927316482919244-11.828715103889593303im
8383
@test dawson(ComplexF64(1+2im)) -13.388927316482919244-11.828715103889593303im
8484
end
@@ -116,6 +116,18 @@
116116
end
117117
end
118118

119+
@testset "Other float types" begin
120+
struct NotAFloat <: AbstractFloat end
121+
122+
@test_throws MethodError erf(NotAFloat())
123+
@test_throws MethodError erfc(NotAFloat())
124+
@test_throws MethodError erfcx(NotAFloat())
125+
@test_throws MethodError erfi(NotAFloat())
126+
@test_throws MethodError erfinv(NotAFloat())
127+
@test_throws MethodError erfcinv(NotAFloat())
128+
@test_throws MethodError dawson(NotAFloat())
129+
end
130+
119131
@testset "inverse" begin
120132
for elty in [Float32,Float64]
121133
for x in exp10.(range(-200, stop=-0.01, length=50))

0 commit comments

Comments
 (0)