Skip to content

Commit 2f64a3b

Browse files
stevengj authored and StefanKarpinski committed
Faster, more correct complex^complex (#24570)
1 parent df2616e commit 2f64a3b

File tree

3 files changed

+146
-80
lines changed

3 files changed

+146
-80
lines changed

base/complex.jl

Lines changed: 74 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -648,98 +648,89 @@ end
648648
exp10(z::Complex) = exp10(float(z))
649649

650650
# _cpow helper function to avoid method ambiguity with ^(::Complex,::Real)
651-
function _cpow(z::Complex{T}, p::Union{T,Complex{T}})::Complex{T} where T<:AbstractFloat
652-
if p == 2 #square
653-
zr, zi = reim(z)
654-
x = (zr-zi)*(zr+zi)
655-
y = 2*zr*zi
656-
if isnan(x)
657-
if isinf(y)
658-
x = copysign(zero(T),zr)
659-
elseif isinf(zi)
660-
x = convert(T,-Inf)
661-
elseif isinf(zr)
662-
x = convert(T,Inf)
651+
function _cpow(z::Union{T,Complex{T}}, p::Union{T,Complex{T}}) where {T<:AbstractFloat}
652+
if isreal(p)
653+
pᵣ = real(p)
654+
if isinteger(pᵣ) && abs(pᵣ) < typemax(Int32)
655+
# |p| < typemax(Int32) serves two purposes: it prevents overflow
656+
# when converting p to Int, and it also turns out to be roughly
657+
# the crossover point for exp(p*log(z)) or similar to be faster.
658+
if iszero(pᵣ) # fix signs of imaginary part for z^0
659+
zer = flipsign(copysign(zero(T),pᵣ), imag(z))
660+
return Complex(one(T), zer)
663661
end
664-
elseif isnan(y) && isinf(x)
665-
y = copysign(zero(T), y)
666-
end
667-
Complex(x,y)
668-
elseif z!=0
669-
if p!=0 && isinteger(p)
670-
rp = real(p)
671-
if rp < 0
672-
return power_by_squaring(inv(z), convert(Integer, -rp))
673-
else
674-
return power_by_squaring(z, convert(Integer, rp))
675-
end
676-
end
677-
exp(p*log(z))
678-
elseif p!=0 #0^p
679-
zero(z) #CHECK SIGNS
680-
else #0^0
681-
zer = copysign(zero(T),real(p))*copysign(zero(T),imag(z))
682-
Complex(one(T), zer)
683-
end
684-
end
685-
function _cpow(z::Complex{T}, p::Union{T,Complex{T}}) where T<:Real
686-
if isinteger(p)
687-
rp = real(p)
688-
if rp < 0
689-
return power_by_squaring(inv(float(z)), convert(Integer, -rp))
690-
else
691-
return power_by_squaring(float(z), convert(Integer, rp))
692-
end
693-
end
694-
pr, pim = reim(p)
695-
zr, zi = reim(z)
696-
r = abs(z)
697-
rp = r^pr
698-
theta = atan2(zi, zr)
699-
ntheta = pr*theta
700-
if pim != 0 && r != 0
701-
rp = rp*exp(-pim*theta)
702-
ntheta = ntheta + pim*log(r)
703-
end
704-
sinntheta, cosntheta = sincos(ntheta)
705-
re, im = rp*cosntheta, rp*sinntheta
706-
if isinf(rp)
707-
if isnan(re)
708-
re = copysign(zero(re), cosntheta)
709-
end
710-
if isnan(im)
711-
im = copysign(zero(im), sinntheta)
712-
end
713-
end
714-
715-
# apply some corrections to force known zeros
716-
if pim == 0
717-
if isinteger(pr)
718-
if zi == 0
719-
im = copysign(zero(im), im)
720-
elseif zr == 0
721-
if isinteger(0.5*pr) # pr is even
722-
im = copysign(zero(im), im)
662+
ip = convert(Int, pᵣ)
663+
if isreal(z)
664+
zᵣ = real(z)
665+
if ip < 0
666+
iszero(z) && return Complex(T(NaN),T(NaN))
667+
re = power_by_squaring(inv(zᵣ), -ip)
668+
im = -imag(z)
723669
else
724-
re = copysign(zero(re), re)
670+
re = power_by_squaring(zᵣ, ip)
671+
im = imag(z)
725672
end
673+
# slightly tricky to get the correct sign of zero imag. part
674+
return Complex(re, ifelse(iseven(ip) & signbit(zᵣ), -im, im))
675+
else
676+
return ip < 0 ? power_by_squaring(inv(z), -ip) : power_by_squaring(z, ip)
726677
end
727-
else
728-
dr = pr*2
729-
if isinteger(dr) && zi == 0
730-
if zr < 0
731-
re = copysign(zero(re), re)
678+
elseif isreal(z)
679+
# (note: if both z and p are complex with ±0.0 imaginary parts,
680+
# the sign of the ±0.0 imaginary part of the result is ambiguous)
681+
if iszero(real(z))
682+
return pᵣ > 0 ? complex(z) : Complex(T(NaN),T(NaN)) # 0 or NaN+NaN*im
683+
elseif real(z) > 0
684+
return Complex(real(z)^pᵣ, z isa Real ? ifelse(real(z) < 1, -imag(p), imag(p)) : flipsign(imag(z), pᵣ))
685+
else
686+
zᵣ = real(z)
687+
rᵖ = (-zᵣ)^pᵣ
688+
if isfinite(pᵣ)
689+
# figuring out the sign of 0.0 when p is a complex number
690+
# with zero imaginary part and integer/2 real part could be
691+
# improved here, but it's not clear if it's worth it…
692+
return rᵖ * complex(cospi(pᵣ), flipsign(sinpi(pᵣ),imag(z)))
732693
else
733-
im = copysign(zero(im), im)
694+
iszero(rᵖ) && return zero(Complex{T}) # no way to get correct signs of 0.0
695+
return Complex(T(NaN),T(NaN)) # non-finite phase angle or NaN input
734696
end
735697
end
698+
else
699+
rᵖ = abs(z)^pᵣ
700+
ϕ = pᵣ*angle(z)
736701
end
702+
elseif isreal(z)
703+
iszero(z) && return real(p) > 0 ? complex(z) : Complex(T(NaN),T(NaN)) # 0 or NaN+NaN*im
704+
zᵣ = real(z)
705+
pᵣ, pᵢ = reim(p)
706+
if zᵣ > 0
707+
rᵖ = zᵣ^pᵣ
708+
ϕ = pᵢ*log(zᵣ)
709+
else
710+
r = -zᵣ
711+
θ = copysign(T(π),imag(z))
712+
rᵖ = r^pᵣ * exp(-pᵢ*θ)
713+
ϕ = pᵣ*θ + pᵢ*log(r)
714+
end
715+
else
716+
pᵣ, pᵢ = reim(p)
717+
r = abs(z)
718+
θ = angle(z)
719+
rᵖ = r^pᵣ * exp(-pᵢ*θ)
720+
ϕ = pᵣ*θ + pᵢ*log(r)
737721
end
738722

739-
Complex(re, im)
723+
if isfinite(ϕ)
724+
return rᵖ * cis(ϕ)
725+
else
726+
iszero(rᵖ) && return zero(Complex{T}) # no way to get correct signs of 0.0
727+
return Complex(T(NaN),T(NaN)) # non-finite phase angle or NaN input
728+
end
740729
end
730+
_cpow(z, p) = _cpow(float(z), float(p))
741731
^(z::Complex{T}, p::Complex{T}) where T<:Real = _cpow(z, p)
742732
^(z::Complex{T}, p::T) where T<:Real = _cpow(z, p)
733+
^(z::T, p::Complex{T}) where T<:Real = _cpow(z, p)
743734

744735
^(z::Complex, n::Bool) = n ? z : one(z)
745736
^(z::Complex, n::Integer) = z^Complex(n)
@@ -755,6 +746,10 @@ function ^(z::Complex{T}, p::S) where {T<:Real,S<:Real}
755746
P = promote_type(T,S)
756747
return Complex{P}(z) ^ P(p)
757748
end
749+
function ^(z::T, p::Complex{S}) where {T<:Real,S<:Real}
750+
P = promote_type(T,S)
751+
return P(z) ^ Complex{P}(p)
752+
end
758753

759754
function sin(z::Complex{T}) where T
760755
F = float(T)

base/mathconstants.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ catalan = 0.9159655941772...
8282
catalan
8383

8484
# loop over types to prevent ambiguities for ^(::Number, x)
85-
for T in (AbstractIrrational, Rational, Integer, Number)
85+
for T in (AbstractIrrational, Rational, Integer, Number, Complex)
8686
Base.:^(::Irrational{:ℯ}, x::T) = exp(x)
8787
end
8888
Base.literal_pow(::typeof(^), ::Irrational{:ℯ}, ::Val{p}) where {p} = exp(p)

test/complex.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -998,3 +998,74 @@ end
998998
end
999999
@test (2+0im)^(-21//10) === (2//1+0im)^(-21//10) === 2^-2.1 - 0.0im
10001000
end
1001+
1002+
@testset "more cpow" begin
1003+
# for testing signs of zeros, it is useful to convert ±0.0 to ±1e-15
1004+
zero2small(r::Real) = iszero(r) ? copysign(1e-15, r) : r
1005+
zero2small(z::Complex) = complex(zero2small(real(z)), zero2small(imag(z)))
1006+
≋(x::Real, y::Real) = x*y == 0 ? abs(x) < 1e-8 && abs(y) < 1e-8 && signbit(x)==signbit(y) : isfinite(x) ? x ≈ y : isequal(x, y)
1007+
≋(x::Complex, y::Complex) = real(x) ≋ real(y) && imag(x) ≋ imag(y)
1008+
≋(x,y) = isequal(x,y)
1009+
1010+
# test z^p for positive/negative/zero real and imaginary parts of z and p:
1011+
v=(-2.7,-3.0,-2.0,-0.0,+0.0,2.0,3.0,2.7)
1012+
for zr=v, zi=v, pr=v, pi=v
1013+
z = complex(zr,zi)
1014+
p = iszero(pi) ? pr : complex(pr,pi)
1015+
if isinteger(p)
1016+
c = zero2small(z)^Integer(pr)
1017+
else
1018+
c = exp(zero2small(p) * log(zero2small(z)))
1019+
end
1020+
if !iszero(z*p) # z==0 or p==0 is tricky, check it separately
1021+
@test z^p ≋ c
1022+
if isreal(p)
1023+
@test z^(p + 1e-15im) ≈ z^(p - 1e-15im) ≈ c
1024+
if isinteger(p)
1025+
@test isequal(z^Integer(pr), z^p)
1026+
end
1027+
elseif (zr != 0 || !signbit(zr)) && (zi != 0 || !signbit(zi))
1028+
@test isequal((Complex{Int}(z*10)//10)^p, z^p)
1029+
end
1030+
end
1031+
end
1032+
1033+
@test 2 ^ (0.3 + 0.0im) === 2.0 ^ (0.3 + 0.0im) === conj(2.0 ^ (0.3 - 0.0im)) ≋ 2.0 ^ (0.3 + 1e-15im)
1034+
@test 0.2 ^ (0.3 + 0.0im) === conj(0.2 ^ (0.3 - 0.0im)) ≋ 0.2 ^ (0.3 + 1e-15im)
1035+
@test (0.0 - 0.0im)^2.0 === (0.0 - 0.0im)^2 === (0.0 - 0.0im)^1.1 === (0.0 - 0.0im) ^ (1.1 + 2.3im) === 0.0 - 0.0im
1036+
@test (0.0 - 0.0im)^-2.0 ≋ (0.0 - 0.0im)^-2 ≋ (0.0 - 0.0im)^-1.1 ≋ (0.0 - 0.0im) ^ (-1.1 + 2.3im) ≋ NaN + NaN*im
1037+
@test (1.0+0.0)^(1.2+0.7im) === 1.0 + 0.0im
1038+
@test (-1.0+0.0)^(2.0+0.7im) ≈ exp(-0.7π)
1039+
@test (-4.0+0.0im)^1.5 === (-4.0)^(1.5+0.0im) === (-4)^(1.5+0.0im) === (-4)^(3//2+0im) === 0.0 - 8.0im
1040+
1041+
# issue #24515:
1042+
@test (Inf + Inf*im)^2.0 ≋ (Inf + Inf*im)^2 ≋ NaN + Inf*im
1043+
@test (0+0im)^-3.0 ≋ (0+0im)^-3 ≋ NaN + NaN*im
1044+
@test (1.0+0.0im)^1e300 === 1.0 + 0.0im
1045+
@test Inf^(-Inf + 0.0im) == (Inf + 0.0im)^(-Inf - 0.0im) == (Inf - 0.0im)^(-Inf - 0.0im) == (Inf - 0.0im)^-Inf == 0
1046+
1047+
# NaN propagation
1048+
@test (0 + NaN*im)^1 ≋ (0 + NaN*im)^1.0 ≋ (0 + NaN*im)^(1.0+0im) ≋ 0.0 + NaN*im
1049+
@test (0 + NaN*im)^2 ≋ (0 + NaN*im)^2.0 ≋ (0 + NaN*im)^(2.0+0im) ≋ NaN + NaN*im
1050+
@test (NaN + 0im)^2.0 ≋ (NaN + 0im)^(2.0+0im) ≋ (2+0im)^NaN ≋ NaN + 0im
1051+
@test (NaN + 0im)^2.5 ≋ NaN^(2.5+0im) ≋ (NaN + NaN*im)^2.5 ≋ (-2+0im)^NaN ≋ (2+0im)^(1+NaN*im) ≋ NaN + NaN*im
1052+
1053+
# more Inf cases:
1054+
@test (Inf + 0im)^Inf === Inf^(Inf + 0im) === (Inf + 0im)^(Inf + 0im) == Inf + 0im
1055+
@test (-Inf + 0im)^(0.7 + 0im) === (-Inf + 1im)^(0.7 + 0im) === conj((-Inf - 1im)^(0.7 + 0im)) === -Inf + Inf*im
1056+
@test (-Inf + 0.0im) ^ 3.1 === conj((-Inf - 0.0im) ^ 3.1) === -Inf - Inf*im
1057+
@test (3.0+0.0im)^(Inf + 1im) === (3.0-0.0im)^(Inf + 1im) === conj((3.0+0.0im)^(Inf - 1im)) === Inf + Inf*im
1058+
1059+
# The following cases should arguably give Inf + Inf*im, but currently
1060+
# give partial NaNs instead. Marking as broken for now (since Julia 0.4 at least),
1061+
# in the hope that someday we can fix these corner cases. (Python gets them wrong too.)
1062+
@test_broken (Inf + 1im)^3 === (Inf + 1im)^3.0 === (Inf + 1im)^(3+0im) === Inf + Inf*im
1063+
@test_broken (Inf + 1im)^3.1 === (Inf + 1im)^(3.1+0im) === Inf + Inf*im
1064+
1065+
# cases where phase angle is non-finite yield NaN + NaN*im:
1066+
@test NaN + NaN*im ≋ Inf ^ (2 + 3im) ≋ (Inf + 1im) ^ (2 + 3im) ≋ (Inf*im) ^ (2 + 3im) ≋
1067+
3^(Inf*im) ≋ (-3)^(Inf + 0im) ≋ (-3)^(Inf + 1im) ≋ (3+1im)^Inf ≋
1068+
(3+1im)^(Inf + 1im) ≋ (1e200+1e-200im)^Inf ≋ (1e200+1e-200im)^(Inf+1im)
1069+
1070+
@test @inferred(2.0^(3.0+0im)) === @inferred((2.0+0im)^(3.0+0im)) === @inferred((2.0+0im)^3.0) === 8.0+0.0im
1071+
end

0 commit comments

Comments
 (0)