From 61df653f0aa163e22eeaa0b7dab588aeba3497ae Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Thu, 28 Nov 2024 08:56:04 -0500 Subject: [PATCH 01/10] fix evalpoly type instability and 0-length case --- base/math.jl | 22 +++++++++++++--------- test/math.jl | 4 ++++ 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/base/math.jl b/base/math.jl index 16a8a547e8de1..d012d214a1359 100644 --- a/base/math.jl +++ b/base/math.jl @@ -23,7 +23,7 @@ import .Base: log, exp, sin, cos, tan, sinh, cosh, tanh, asin, using .Base: sign_mask, exponent_mask, exponent_one, exponent_half, uinttype, significand_mask, significand_bits, exponent_bits, exponent_bias, - exponent_max, exponent_raw_max, clamp, clamp! + exponent_max, exponent_raw_max, clamp, clamp!, reduce_empty_iter using Core.Intrinsics: sqrt_llvm @@ -103,17 +103,20 @@ function evalpoly(x, p::Tuple) _evalpoly(x, p) end end +evalpoly(x, ::Tuple{}) = zero(one(x)) # dimensionless zero, i.e. 0 * x^0 evalpoly(x, p::AbstractVector) = _evalpoly(x, p) function _evalpoly(x, p) Base.require_one_based_indexing(p) N = length(p) - ex = p[end] + @inbounds p0 = N == 0 ? reduce_empty_iter(+, p) : p[N] + s = oftype(one(x) * p0, p0) + N <= 1 && return s for i in N-1:-1:1 - ex = muladd(x, ex, p[i]) + @inbounds s = muladd(x, s, p[i]) end - ex + return s end function evalpoly(z::Complex, p::Tuple) @@ -142,16 +145,17 @@ function evalpoly(z::Complex, p::Tuple) end end evalpoly(z::Complex, p::Tuple{<:Any}) = p[1] - +evalpoly(z::Complex, ::Tuple{}) = zero(one(z)) # dimensionless zero, i.e. 0 * z^0 evalpoly(z::Complex, p::AbstractVector) = _evalpoly(z, p) function _evalpoly(z::Complex, p) Base.require_one_based_indexing(p) - length(p) == 1 && return p[1] N = length(p) - a = p[end] - b = p[end-1] + @inbounds p0 = N == 0 ? reduce_empty_iter(+, p) : p[N] + N <= 1 && return oftype(one(z) * p0, p0) + a = p0 + @inbounds b = p[N-1] x = real(z) y = imag(z) @@ -160,7 +164,7 @@ function _evalpoly(z::Complex, p) for i in N-2:-1:1 ai = a a = muladd(r, ai, b) - b = muladd(-s, ai, p[i]) + @inbounds b = muladd(-s, ai, p[i]) end ai = a muladd(ai, z, b) diff --git a/test/math.jl b/test/math.jl index 7070fe63ba931..6842ac97bdb60 100644 --- a/test/math.jl +++ b/test/math.jl @@ -715,6 +715,8 @@ end @test evalpoly(x, (p1, p2, p3)) == evpm @test evalpoly(x, [p1, p2, p3]) == evpm end + @test evalpoly(1.0f0, ()) === 0.0f0 # issue #56699 + @test @inferred(evalpoly(1.0f0, [2])) === 2.0f0 # type-stability end @testset "evalpoly complex" begin @@ -726,6 +728,8 @@ end end @test evalpoly(1+im, (2,)) == 2 @test evalpoly(1+im, [2,]) == 2 + @test evalpoly(1.0f0+im, ()) === 0.0f0+0im # issue #56699 + @test @inferred(evalpoly(1.0f0+im, [2])) === 2.0f0+0im # type-stability end @testset "cis" begin From 1ee45506b6f366bbac98730f2a805943ac411ce3 Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Thu, 28 Nov 2024 09:13:51 -0500 Subject: [PATCH 02/10] more evalpoly tests --- test/math.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/math.jl b/test/math.jl index 6842ac97bdb60..f15d1c3b0cc33 100644 --- a/test/math.jl +++ b/test/math.jl @@ -716,6 +716,8 @@ end @test evalpoly(x, [p1, p2, p3]) == evpm end @test evalpoly(1.0f0, ()) === 0.0f0 # issue #56699 + @test @inferred(evalpoly(1.0f0, Int[])) === 0.0f0 # issue #56699 + @test_throws MethodError evalpoly(1.0f0, []) @test @inferred(evalpoly(1.0f0, [2])) === 2.0f0 # type-stability end @@ -729,6 +731,8 @@ end @test evalpoly(1+im, (2,)) == 2 @test evalpoly(1+im, [2,]) == 2 @test evalpoly(1.0f0+im, ()) === 0.0f0+0im # issue #56699 + @test @inferred(evalpoly(1.0f0+im, Int[])) === 0.0f0+0im # issue #56699 + @test_throws MethodError evalpoly(1.0f0, []) @test @inferred(evalpoly(1.0f0+im, [2])) === 2.0f0+0im # type-stability end From 19950314c9567107faa7b0b1810bffaf1f741dd7 Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Thu, 28 Nov 2024 20:33:21 -0500 Subject: [PATCH 03/10] rm unnecessary test --- base/math.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/base/math.jl b/base/math.jl index d012d214a1359..63ac6f94264ec 100644 --- a/base/math.jl +++ b/base/math.jl @@ -112,7 +112,6 @@ function _evalpoly(x, p) N = length(p) @inbounds p0 = N == 0 ? reduce_empty_iter(+, p) : p[N] s = oftype(one(x) * p0, p0) - N <= 1 && return s for i in N-1:-1:1 @inbounds s = muladd(x, s, p[i]) end From d22d8f109a1df0942274020b9376d9fadb7c7426 Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Sat, 30 Nov 2024 10:19:33 -0500 Subject: [PATCH 04/10] rm extraneous methods --- base/math.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/base/math.jl b/base/math.jl index 63ac6f94264ec..5001f65ebd20c 100644 --- a/base/math.jl +++ b/base/math.jl @@ -118,7 +118,8 @@ function _evalpoly(x, p) return s end -function evalpoly(z::Complex, p::Tuple) +# Goertzel-like algorithm from Knuth, TAOCP vol. 2, section 4.6.4: +function evalpoly(z::Complex, p::Tuple{Any, Any, Vararg}) if @generated N = length(p.parameters) a = :(p[end]) @@ -143,10 +144,6 @@ function evalpoly(z::Complex, p::Tuple) _evalpoly(z, p) end end -evalpoly(z::Complex, p::Tuple{<:Any}) = p[1] -evalpoly(z::Complex, ::Tuple{}) = zero(one(z)) # dimensionless zero, i.e. 0 * z^0 - -evalpoly(z::Complex, p::AbstractVector) = _evalpoly(z, p) function _evalpoly(z::Complex, p) Base.require_one_based_indexing(p) From 7a1d9a2ca411dd04796bea429ca7d686ba0602d5 Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Thu, 2 Jan 2025 10:01:18 -0500 Subject: [PATCH 05/10] added test for empty-coefficients case, with units --- test/math.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/math.jl b/test/math.jl index f15d1c3b0cc33..892af88c6b65f 100644 --- a/test/math.jl +++ b/test/math.jl @@ -707,6 +707,10 @@ end c = 3 @test @evalpoly(c, a0, a1) == 7 @test @evalpoly(1, 2) == 2 + + isdefined(Main, :Furlongs) || @eval Main include("testhelpers/Furlongs.jl") + using .Main.Furlongs + @test @evalpoly(Furlong(2)) === evalpoly(Furlong(2), ()) === evalpoly(Furlong(2), Int[]) === 0 end @testset "evalpoly real" begin From cf2a6b903d6643239d242352dced59b179fdf64a Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Fri, 29 Nov 2024 13:48:25 -0500 Subject: [PATCH 06/10] Update base/math.jl Co-authored-by: Neven Sajko --- base/math.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base/math.jl b/base/math.jl index 5001f65ebd20c..30e7052876c3b 100644 --- a/base/math.jl +++ b/base/math.jl @@ -148,7 +148,7 @@ end function _evalpoly(z::Complex, p) Base.require_one_based_indexing(p) N = length(p) - @inbounds p0 = N == 0 ? reduce_empty_iter(+, p) : p[N] + p0 = iszero(N) ? reduce_empty_iter(+, p) : @inbounds p[N] N <= 1 && return oftype(one(z) * p0, p0) a = p0 @inbounds b = p[N-1] From 9df2b24b6a24d201c8f9449e97bf1ed3f1b1c606 Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Fri, 29 Nov 2024 13:47:47 -0500 Subject: [PATCH 07/10] Update base/math.jl Co-authored-by: Neven Sajko --- base/math.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base/math.jl b/base/math.jl index 30e7052876c3b..1fb4e840deca1 100644 --- a/base/math.jl +++ b/base/math.jl @@ -160,7 +160,7 @@ function _evalpoly(z::Complex, p) for i in N-2:-1:1 ai = a a = muladd(r, ai, b) - @inbounds b = muladd(-s, ai, p[i]) + b = muladd(-s, ai, @inbounds p[i]) end ai = a muladd(ai, z, b) From 1d42f4e8e74037e63e16891e41478fb59ee76273 Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Fri, 29 Nov 2024 13:47:27 -0500 Subject: [PATCH 08/10] Update base/math.jl Co-authored-by: Neven Sajko --- base/math.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base/math.jl b/base/math.jl index 1fb4e840deca1..4c79a24e167ad 100644 --- a/base/math.jl +++ b/base/math.jl @@ -113,7 +113,7 @@ function _evalpoly(x, p) @inbounds p0 = N == 0 ? reduce_empty_iter(+, p) : p[N] s = oftype(one(x) * p0, p0) for i in N-1:-1:1 - @inbounds s = muladd(x, s, p[i]) + s = muladd(x, s, @inbounds p[i]) end return s end From 05cfd26f576997d6a8981d60eee03219430a1d35 Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Fri, 29 Nov 2024 13:47:19 -0500 Subject: [PATCH 09/10] Update base/math.jl Co-authored-by: Neven Sajko --- base/math.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base/math.jl b/base/math.jl index 4c79a24e167ad..bf9c0696ff8a8 100644 --- a/base/math.jl +++ b/base/math.jl @@ -110,7 +110,7 @@ evalpoly(x, p::AbstractVector) = _evalpoly(x, p) function _evalpoly(x, p) Base.require_one_based_indexing(p) N = length(p) - @inbounds p0 = N == 0 ? reduce_empty_iter(+, p) : p[N] + p0 = iszero(N) ? reduce_empty_iter(+, p) : @inbounds p[N] s = oftype(one(x) * p0, p0) for i in N-1:-1:1 s = muladd(x, s, @inbounds p[i]) From f7b9e886145e3723d3b9a1bb93be0206f7e8efac Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Thu, 2 Jan 2025 10:27:27 -0500 Subject: [PATCH 10/10] ensure that different generated brnches return same type --- base/math.jl | 2 +- test/math.jl | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/base/math.jl b/base/math.jl index bf9c0696ff8a8..d3b633a3fadba 100644 --- a/base/math.jl +++ b/base/math.jl @@ -94,7 +94,7 @@ julia> evalpoly(2, (1, 2, 3)) function evalpoly(x, p::Tuple) if @generated N = length(p.parameters::Core.SimpleVector) - ex = :(p[end]) + ex = :(oftype(one(x) * p[end], p[end])) for i in N-1:-1:1 ex = :(muladd(x, $ex, p[$i])) end diff --git a/test/math.jl b/test/math.jl index 892af88c6b65f..d82e98ac8db9b 100644 --- a/test/math.jl +++ b/test/math.jl @@ -723,6 +723,9 @@ end @test @inferred(evalpoly(1.0f0, Int[])) === 0.0f0 # issue #56699 @test_throws MethodError evalpoly(1.0f0, []) @test @inferred(evalpoly(1.0f0, [2])) === 2.0f0 # type-stability + + # different @generated branches should return same type: + @test evalpoly(3.0, (1,)) === Base.Math._evalpoly(3.0, (1,)) === 1.0 end @testset "evalpoly complex" begin @@ -738,6 +741,10 @@ end @test @inferred(evalpoly(1.0f0+im, Int[])) === 0.0f0+0im # issue #56699 @test_throws MethodError evalpoly(1.0f0, []) @test @inferred(evalpoly(1.0f0+im, [2])) === 2.0f0+0im # type-stability + + # different @generated branches should return same type: + @test evalpoly(3.0+0im, (1,)) === Base.Math._evalpoly(3.0+0im, (1,)) === 1.0+0im + @test evalpoly(3.0+0im, (1,2)) === Base.Math._evalpoly(3.0+0im, (1,2)) === 7.0+0im end @testset "cis" begin