diff --git a/base/math.jl b/base/math.jl index 16a8a547e8de1..d3b633a3fadba 100644 --- a/base/math.jl +++ b/base/math.jl @@ -23,7 +23,7 @@ import .Base: log, exp, sin, cos, tan, sinh, cosh, tanh, asin, using .Base: sign_mask, exponent_mask, exponent_one, exponent_half, uinttype, significand_mask, significand_bits, exponent_bits, exponent_bias, - exponent_max, exponent_raw_max, clamp, clamp! + exponent_max, exponent_raw_max, clamp, clamp!, reduce_empty_iter using Core.Intrinsics: sqrt_llvm @@ -94,7 +94,7 @@ julia> evalpoly(2, (1, 2, 3)) function evalpoly(x, p::Tuple) if @generated N = length(p.parameters::Core.SimpleVector) - ex = :(p[end]) + ex = :(oftype(one(x) * p[end], p[end])) for i in N-1:-1:1 ex = :(muladd(x, $ex, p[$i])) end @@ -103,20 +103,23 @@ function evalpoly(x, p::Tuple) _evalpoly(x, p) end end +evalpoly(x, ::Tuple{}) = zero(one(x)) # dimensionless zero, i.e. 0 * x^0 evalpoly(x, p::AbstractVector) = _evalpoly(x, p) function _evalpoly(x, p) Base.require_one_based_indexing(p) N = length(p) - ex = p[end] + p0 = iszero(N) ? reduce_empty_iter(+, p) : @inbounds p[N] + s = oftype(one(x) * p0, p0) for i in N-1:-1:1 - ex = muladd(x, ex, p[i]) + s = muladd(x, s, @inbounds p[i]) end - ex + return s end -function evalpoly(z::Complex, p::Tuple) +# Goertzel-like algorithm from Knuth, TAOCP vol. 2, section 4.6.4: +function evalpoly(z::Complex, p::Tuple{Any, Any, Vararg}) if @generated N = length(p.parameters) a = :(p[end]) @@ -141,17 +144,14 @@ function evalpoly(z::Complex, p::Tuple) _evalpoly(z, p) end end -evalpoly(z::Complex, p::Tuple{<:Any}) = p[1] - - -evalpoly(z::Complex, p::AbstractVector) = _evalpoly(z, p) function _evalpoly(z::Complex, p) Base.require_one_based_indexing(p) - length(p) == 1 && return p[1] N = length(p) - a = p[end] - b = p[end-1] + p0 = iszero(N) ? reduce_empty_iter(+, p) : @inbounds p[N] + N <= 1 && return oftype(one(z) * p0, p0) + a = p0 + @inbounds b = p[N-1] x = real(z) y = imag(z) @@ -160,7 +160,7 @@ function _evalpoly(z::Complex, p) for i in N-2:-1:1 ai = a a = muladd(r, ai, b) - b = muladd(-s, ai, p[i]) + b = muladd(-s, ai, @inbounds p[i]) end ai = a muladd(ai, z, b) diff --git a/test/math.jl b/test/math.jl index 7070fe63ba931..d82e98ac8db9b 100644 --- a/test/math.jl +++ b/test/math.jl @@ -707,6 +707,10 @@ end c = 3 @test @evalpoly(c, a0, a1) == 7 @test @evalpoly(1, 2) == 2 + + isdefined(Main, :Furlongs) || @eval Main include("testhelpers/Furlongs.jl") + using .Main.Furlongs + @test @evalpoly(Furlong(2)) === evalpoly(Furlong(2), ()) === evalpoly(Furlong(2), Int[]) === 0 end @testset "evalpoly real" begin @@ -715,6 +719,13 @@ end @test evalpoly(x, (p1, p2, p3)) == evpm @test evalpoly(x, [p1, p2, p3]) == evpm end + @test evalpoly(1.0f0, ()) === 0.0f0 # issue #56699 + @test @inferred(evalpoly(1.0f0, Int[])) === 0.0f0 # issue #56699 + @test_throws MethodError evalpoly(1.0f0, []) + @test @inferred(evalpoly(1.0f0, [2])) === 2.0f0 # type-stability + + # different @generated branches should return same type: + @test evalpoly(3.0, (1,)) === Base.Math._evalpoly(3.0, (1,)) === 1.0 end @testset "evalpoly complex" begin @@ -726,6 +737,14 @@ end end @test evalpoly(1+im, (2,)) == 2 @test evalpoly(1+im, [2,]) == 2 + @test evalpoly(1.0f0+im, ()) === 0.0f0+0im # issue #56699 + @test @inferred(evalpoly(1.0f0+im, Int[])) === 0.0f0+0im # issue #56699 + @test_throws MethodError evalpoly(1.0f0, []) + @test @inferred(evalpoly(1.0f0+im, [2])) === 2.0f0+0im # type-stability + + # different @generated branches should return same type: + @test evalpoly(3.0+0im, (1,)) === Base.Math._evalpoly(3.0+0im, (1,)) === 1.0+0im + @test evalpoly(3.0+0im, (1,2)) === Base.Math._evalpoly(3.0+0im, (1,2)) === 7.0+0im end @testset "cis" begin