Skip to content

Commit 124915f

Browse files
Add erfinv and erfcinv for Float16 and generalize logerfc and logerfcx (#372)
* Add `erfinv` and `erfcinv` for `Float16` * Generalize `logerfc` and `logerfcx` * Add tests * Update version number * Update test/erf.jl * Fix test * Simplify branch for `abs(x) >= 1` in `_erfinv` * Fix incomplete function * Simplify implementation * Apply suggested alternative --------- Co-authored-by: Viral B. Shah <ViralBShah@users.noreply.github.com>
1 parent e08ff8d commit 124915f

File tree

3 files changed

+104
-80
lines changed

3 files changed

+104
-80
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "SpecialFunctions"
22
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
3-
version = "2.3.1"
3+
version = "2.4.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/erf.jl

Lines changed: 65 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -253,13 +253,10 @@ erfinv(x::Real) = _erfinv(float(x))
253253

254254
function _erfinv(x::Float64)
255255
a = abs(x)
256-
if a >= 1.0
257-
if x == 1.0
258-
return Inf
259-
elseif x == -1.0
260-
return -Inf
261-
end
256+
if a > 1.0
262257
throw(DomainError(a, "`abs(x)` cannot be greater than 1."))
258+
elseif a == 1.0
259+
return copysign(Inf, x)
263260
elseif a <= 0.75 # Table 17 in Blair et al.
264261
t = x*x - 0.5625
265262
return x * @horner(t, 0.16030_49558_44066_229311e2,
@@ -321,13 +318,10 @@ end
321318

322319
function _erfinv(x::Float32)
323320
a = abs(x)
324-
if a >= 1.0f0
325-
if x == 1.0f0
326-
return Inf32
327-
elseif x == -1.0f0
328-
return -Inf32
329-
end
321+
if a > 1f0
330322
throw(DomainError(a, "`abs(x)` cannot be greater than 1."))
323+
elseif a == 1f0
324+
return copysign(Inf32, x)
331325
elseif a <= 0.75f0 # Table 10 in Blair et al.
332326
t = x*x - 0.5625f0
333327
return x * @horner(t, -0.13095_99674_22f2,
@@ -362,6 +356,42 @@ function _erfinv(x::Float32)
362356
end
363357
end
364358

359+
function _erfinv(x::Float16)
360+
a = abs(x)
361+
if a > Float16(1)
362+
throw(DomainError(a, "`abs(x)` cannot be greater than 1."))
363+
elseif a == Float16(1)
364+
return copysign(Inf16, x)
365+
else
366+
# Perform calculations with `Float32`
367+
x32 = Float32(x)
368+
a32 = Float32(a)
369+
if a32 <= 0.75f0
370+
# Simpler and more accurate alternative to Table 7 in Blair et al.
371+
# Ref: https://github.com/JuliaMath/SpecialFunctions.jl/pull/372#discussion_r1592832735
372+
t = muladd(-6.73815f1, x32, 1f0) / muladd(-4.18798f0, x32, 4.54263f0)
373+
y = copysign(muladd(0.88622695f0, x32, t), x32)
374+
elseif a32 <= 0.9375f0 # Table 26 in Blair et al.
375+
t = x32^2 - 0.87890625f0
376+
y = x32 * @horner(t, 0.10178_950f1,
377+
-0.32827_601f1) /
378+
@horner(t, 0.72455_99f0,
379+
-0.33871_553f1,
380+
0.1f1)
381+
else
382+
# Simpler alternative to Table 47 in Blair et al.
383+
# because of the reduced accuracy requirement
384+
# (it turns out that this branch only covers 128 values).
385+
# Note that the use of log(1-x) rather than log1p is intentional since it will be
386+
# slightly faster and 1-x is exact.
387+
# Ref: https://github.com/JuliaMath/SpecialFunctions.jl/pull/372#discussion_r1592710586
388+
t = sqrt(-log(1-a32))
389+
y = copysign(@horner(t, -0.429159f0, 1.04868f0), x32)
390+
end
391+
return Float16(y)
392+
end
393+
end
394+
365395
function _erfinv(y::BigFloat)
366396
xfloat = erfinv(Float64(y))
367397
if isfinite(xfloat)
@@ -482,6 +512,25 @@ function _erfcinv(y::Float32)
482512
end
483513
end
484514

515+
function _erfcinv(y::Float16)
516+
if y > Float16(0.0625)
517+
return erfinv(Float16(1) - y)
518+
elseif y <= Float16(0)
519+
if y == Float16(0)
520+
return Inf16
521+
end
522+
throw(DomainError(y, "`y` must be nonnegative."))
523+
else # Table 47 in Blair et al.
524+
t = 1.0f0 / sqrt(-log(Float32(y)))
525+
x = @horner(t, 0.98650_088f0,
526+
0.92601_777f0) /
527+
(t * @horner(t, 0.98424_719f0,
528+
0.10074_7432f0,
529+
0.1f0))
530+
return Float16(x)
531+
end
532+
end
533+
485534
function _erfcinv(y::BigFloat)
486535
yfloat = Float64(y)
487536
xfloat = erfcinv(yfloat)
@@ -526,13 +575,9 @@ See also: [`erfcx(x)`](@ref erfcx).
526575
527576
# Implementation
528577
Based on the [`erfc(x)`](@ref erfc) and [`erfcx(x)`](@ref erfcx) functions.
529-
Currently only implemented for `Float32`, `Float64`, and `BigFloat`.
530578
"""
531-
logerfc(x::Real) = _logerfc(float(x))
532-
533-
function _logerfc(x::Union{Float32, Float64, BigFloat})
534-
# Don't include Float16 in the Union, otherwise logerfc would currently work for x <= 0.0, but not x > 0.0
535-
if x > 0.0
579+
function logerfc(x::Real)
580+
if x > zero(x)
536581
return log(erfcx(x)) - x^2
537582
else
538583
return log(erfc(x))
@@ -557,13 +602,9 @@ See also: [`erfcx(x)`](@ref erfcx).
557602
558603
# Implementation
559604
Based on the [`erfc(x)`](@ref erfc) and [`erfcx(x)`](@ref erfcx) functions.
560-
Currently only implemented for `Float32`, `Float64`, and `BigFloat`.
561605
"""
562-
logerfcx(x::Real) = _logerfcx(float(x))
563-
564-
function _logerfcx(x::Union{Float32, Float64, BigFloat})
565-
# Don't include Float16 in the Union, otherwise logerfc would currently work for x <= 0.0, but not x > 0.0
566-
if x < 0.0
606+
function logerfcx(x::Real)
607+
if x < zero(x)
567608
return log(erfc(x)) + x^2
568609
else
569610
return log(erfcx(x))

test/erf.jl

Lines changed: 38 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,48 @@
11
@testset "error functions" begin
22
@testset "real argument" begin
3-
@test erf(Float16(1)) 0.84270079294971486934 rtol=2*eps(Float16)
4-
@test erf(Float32(1)) 0.84270079294971486934 rtol=2*eps(Float32)
5-
@test erf(Float64(1)) 0.84270079294971486934 rtol=2*eps(Float64)
6-
7-
@test erfc(Float16(1)) 0.15729920705028513066 rtol=2*eps(Float16)
8-
@test erfc(Float32(1)) 0.15729920705028513066 rtol=2*eps(Float32)
9-
@test erfc(Float64(1)) 0.15729920705028513066 rtol=2*eps(Float64)
10-
11-
@test erfcx(Float16(1)) 0.42758357615580700442 rtol=2*eps(Float16)
12-
@test erfcx(Float32(1)) 0.42758357615580700442 rtol=2*eps(Float32)
13-
@test erfcx(Float64(1)) 0.42758357615580700442 rtol=2*eps(Float64)
14-
15-
@test_throws MethodError logerfc(Float16(1))
16-
@test_throws MethodError logerfc(Float16(-1))
17-
@test logerfc(Float32(-100)) 0.6931471805599453 rtol=2*eps(Float32)
18-
@test logerfc(Float64(-100)) 0.6931471805599453 rtol=2*eps(Float64)
19-
@test logerfc(Float32(1000)) -1.0000074801207219e6 rtol=2*eps(Float32)
20-
@test logerfc(Float64(1000)) -1.0000074801207219e6 rtol=2*eps(Float64)
21-
@test logerfc(1000) -1.0000074801207219e6 rtol=2*eps(Float32)
22-
@test logerfc(Float32(10000)) log(erfc(BigFloat(10000, precision=100))) rtol=2*eps(Float32)
23-
@test logerfc(Float64(10000)) log(erfc(BigFloat(10000, precision=100))) rtol=2*eps(Float64)
24-
25-
@test_throws MethodError logerfcx(Float16(1))
26-
@test_throws MethodError logerfcx(Float16(-1))
27-
@test iszero(logerfcx(0))
28-
@test logerfcx(Float32(1)) -0.849605509933248248576017509499 rtol=2eps(Float32)
29-
@test logerfcx(Float64(1)) -0.849605509933248248576017509499 rtol=2eps(Float32)
30-
@test logerfcx(Float32(-1)) 1.61123231767807049464268192445 rtol=2eps(Float32)
31-
@test logerfcx(Float64(-1)) 1.61123231767807049464268192445 rtol=2eps(Float32)
32-
@test logerfcx(Float32(-100)) 10000.6931471805599453094172321 rtol=2eps(Float32)
33-
@test logerfcx(Float64(-100)) 10000.6931471805599453094172321 rtol=2eps(Float64)
34-
@test logerfcx(Float32(100)) -5.17758512266433257046678208395 rtol=2eps(Float32)
35-
@test logerfcx(Float64(100)) -5.17758512266433257046678208395 rtol=2eps(Float64)
36-
@test logerfcx(Float32(-1000)) 1.00000069314718055994530941723e6 rtol=2eps(Float32)
37-
@test logerfcx(Float64(-1000)) 1.00000069314718055994530941723e6 rtol=2eps(Float64)
38-
@test logerfcx(Float32(1000)) -7.48012072190621214066734919080 rtol=2eps(Float32)
39-
@test logerfcx(Float64(1000)) -7.48012072190621214066734919080 rtol=2eps(Float64)
40-
41-
@test erfi(Float16(1)) 1.6504257587975428760 rtol=2*eps(Float16)
42-
@test erfi(Float32(1)) 1.6504257587975428760 rtol=2*eps(Float32)
43-
@test erfi(Float64(1)) 1.6504257587975428760 rtol=2*eps(Float64)
3+
for T in (Float16, Float32, Float64)
4+
@test @inferred(erf(T(1))) isa T
5+
@test erf(T(1)) T(0.84270079294971486934) rtol=2*eps(T)
446

45-
@test erfinv(Integer(0)) == 0 == erfinv(0//1)
46-
@test_throws MethodError erfinv(Float16(1))
47-
@test erfinv(Float32(0.84270079294971486934)) 1 rtol=2*eps(Float32)
48-
@test erfinv(Float64(0.84270079294971486934)) 1 rtol=2*eps(Float64)
7+
@test @inferred(erfc(T(1))) isa T
8+
@test erfc(T(1)) T(0.15729920705028513066) rtol=2*eps(T)
499

50-
@test erfcinv(Integer(1)) == 0 == erfcinv(1//1)
51-
@test_throws MethodError erfcinv(Float16(1))
52-
@test erfcinv(Float32(0.15729920705028513066)) 1 rtol=2*eps(Float32)
53-
@test erfcinv(Float64(0.15729920705028513066)) 1 rtol=2*eps(Float64)
10+
@test @inferred(erfcx(T(1))) isa T
11+
@test erfcx(T(1)) T(0.42758357615580700442) rtol=2*eps(T)
12+
13+
@test @inferred(logerfc(T(1))) isa T
14+
@test logerfc(T(-100)) T(0.6931471805599453) rtol=2*eps(T)
15+
@test logerfc(T(1000)) T(-1.0000074801207219e6) rtol=2*eps(T)
16+
@test logerfc(T(10000)) T(log(erfc(BigFloat(10000, precision=100)))) rtol=2*eps(T)
17+
18+
@test @inferred(logerfcx(T(1))) isa T
19+
@test logerfcx(T(1)) T(-0.849605509933248248576017509499) rtol=2eps(T)
20+
@test logerfcx(T(-1)) T(1.61123231767807049464268192445) rtol=2eps(T)
21+
@test logerfcx(T(-100)) T(10000.6931471805599453094172321) rtol=2eps(T)
22+
@test logerfcx(T(100)) T(-5.17758512266433257046678208395) rtol=2eps(T)
23+
@test logerfcx(T(-1000)) T(1.00000069314718055994530941723e6) rtol=2eps(T)
24+
@test logerfcx(T(1000)) T(-7.48012072190621214066734919080) rtol=2eps(T)
25+
26+
@test @inferred(erfi(T(1))) isa T
27+
@test erfi(T(1)) T(1.6504257587975428760) rtol=2*eps(T)
28+
29+
@test @inferred(erfinv(T(1))) isa T
30+
@test erfinv(T(0.84270079294971486934)) 1 rtol=2*eps(T)
5431

55-
@test dawson(Float16(1)) 0.53807950691276841914 rtol=2*eps(Float16)
56-
@test dawson(Float32(1)) 0.53807950691276841914 rtol=2*eps(Float32)
57-
@test dawson(Float64(1)) 0.53807950691276841914 rtol=2*eps(Float64)
32+
@test @inferred(erfcinv(T(1))) isa T
33+
@test erfcinv(T(0.15729920705028513066)) 1 rtol=2*eps(T)
5834

35+
@test @inferred(dawson(T(1))) isa T
36+
@test dawson(T(1)) T(0.53807950691276841914) rtol=2*eps(T)
37+
38+
@test @inferred(faddeeva(T(1))) isa Complex{T}
39+
@test faddeeva(T(1)) 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(T)
40+
end
41+
42+
@test logerfc(1000) -1.0000074801207219e6 rtol=2*eps(Float32)
43+
@test erfinv(Integer(0)) == 0 == erfinv(0//1)
44+
@test erfcinv(Integer(1)) == 0 == erfcinv(1//1)
5945
@test faddeeva(0) == faddeeva(0//1) == 1
60-
@test faddeeva(Float16(1)) 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(Float16)
61-
@test faddeeva(Float32(1)) 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(Float32)
62-
@test faddeeva(Float64(1)) 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(Float64)
6346
end
6447

6548
@testset "complex arguments" begin

0 commit comments

Comments
 (0)