Skip to content

Commit aceebea

Browse files
authored
Change in-place exp to out-of-place in matrix trig functions (#56242)
This makes the functions work for arbitrary matrix types that support `exp`, but not necessarily the in-place `exp!`. For example, the following works after this: ```julia julia> m = SMatrix{2,2}(1:4); julia> cos(m) 2×2 SMatrix{2, 2, Float64, 4} with indices SOneTo(2)×SOneTo(2): 0.855423 -0.166315 -0.110876 0.689109 ``` There's a slight performance improvement as well because we don't compute `im*A` and `-im*A` separately, but we negate the first to obtain the second. ```julia julia> A = rand(ComplexF64,100,100); julia> @Btime sin($A); 2.796 ms (48 allocations: 1.84 MiB) # nightly v"1.12.0-DEV.1571" 2.304 ms (48 allocations: 1.84 MiB) # this PR ```
1 parent 3cf803e commit aceebea

File tree

2 files changed

+69
-28
lines changed

2 files changed

+69
-28
lines changed

src/dense.jl

Lines changed: 52 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,6 +1039,11 @@ function inv(A::StridedMatrix{T}) where T
10391039
return Ai
10401040
end
10411041

1042+
# helper function to perform a broadcast in-place if the destination is strided
1043+
# otherwise, this performs an out-of-place broadcast
1044+
@inline _broadcast!!(f, dest::StridedArray, args...) = broadcast!(f, dest, args...)
1045+
@inline _broadcast!!(f, dest, args...) = broadcast(f, args...)
1046+
10421047
"""
10431048
cos(A::AbstractMatrix)
10441049
@@ -1061,19 +1066,22 @@ function cos(A::AbstractMatrix{<:Real})
10611066
elseif issymmetric(A)
10621067
return copytri!(parent(cos(Symmetric(A))), 'U')
10631068
end
1064-
T = complex(float(eltype(A)))
1065-
return real(exp!(T.(im .* A)))
1069+
M = im .* float.(A)
1070+
return real(exp_maybe_inplace(M))
10661071
end
10671072
function cos(A::AbstractMatrix{<:Complex})
10681073
if isdiag(A)
10691074
return applydiagonal(cos, A)
10701075
elseif ishermitian(A)
10711076
return copytri!(parent(cos(Hermitian(A))), 'U', true)
10721077
end
1073-
T = complex(float(eltype(A)))
1074-
X = exp!(T.(im .* A))
1075-
@. X = (X + $exp!(T(-im*A))) / 2
1076-
return X
1078+
M = im .* float.(A)
1079+
N = -M
1080+
X = exp_maybe_inplace(M)
1081+
Y = exp_maybe_inplace(N)
1082+
# Compute (X + Y)/2 and return the result.
1083+
# Compute the result in-place if X is strided
1084+
_broadcast!!((x,y) -> (x + y)/2, X, X, Y)
10771085
end
10781086

10791087
"""
@@ -1098,23 +1106,22 @@ function sin(A::AbstractMatrix{<:Real})
10981106
elseif issymmetric(A)
10991107
return copytri!(parent(sin(Symmetric(A))), 'U')
11001108
end
1101-
T = complex(float(eltype(A)))
1102-
return imag(exp!(T.(im .* A)))
1109+
M = im .* float.(A)
1110+
return imag(exp_maybe_inplace(M))
11031111
end
11041112
function sin(A::AbstractMatrix{<:Complex})
11051113
if isdiag(A)
11061114
return applydiagonal(sin, A)
11071115
elseif ishermitian(A)
11081116
return copytri!(parent(sin(Hermitian(A))), 'U', true)
11091117
end
1110-
T = complex(float(eltype(A)))
1111-
X = exp!(T.(im .* A))
1112-
Y = exp!(T.(.-im .* A))
1113-
@inbounds for i in eachindex(X, Y)
1114-
x, y = X[i]/2, Y[i]/2
1115-
X[i] = Complex(imag(x)-imag(y), real(y)-real(x))
1116-
end
1117-
return X
1118+
M = im .* float.(A)
1119+
Mneg = -M
1120+
X = exp_maybe_inplace(M)
1121+
Y = exp_maybe_inplace(Mneg)
1122+
# Compute (X - Y)/2im and return the result.
1123+
# Compute the result in-place if X is strided
1124+
_broadcast!!((x,y) -> (x - y)/2im, X, X, Y)
11181125
end
11191126

11201127
"""
@@ -1144,8 +1151,8 @@ function sincos(A::AbstractMatrix{<:Real})
11441151
cosA = copytri!(parent(symcosA), 'U')
11451152
return sinA, cosA
11461153
end
1147-
T = complex(float(eltype(A)))
1148-
c, s = reim(exp!(T.(im .* A)))
1154+
M = im .* float.(A)
1155+
c, s = reim(exp_maybe_inplace(M))
11491156
return s, c
11501157
end
11511158
function sincos(A::AbstractMatrix{<:Complex})
@@ -1155,16 +1162,26 @@ function sincos(A::AbstractMatrix{<:Complex})
11551162
cosA = copytri!(parent(hermcosA), 'U', true)
11561163
return sinA, cosA
11571164
end
1158-
T = complex(float(eltype(A)))
1159-
X = exp!(T.(im .* A))
1160-
Y = exp!(T.(.-im .* A))
1165+
M = im .* float.(A)
1166+
Mneg = -M
1167+
X = exp_maybe_inplace(M)
1168+
Y = exp_maybe_inplace(Mneg)
1169+
_sincos(X, Y)
1170+
end
1171+
function _sincos(X::StridedMatrix, Y::StridedMatrix)
11611172
@inbounds for i in eachindex(X, Y)
11621173
x, y = X[i]/2, Y[i]/2
11631174
X[i] = Complex(imag(x)-imag(y), real(y)-real(x))
11641175
Y[i] = x+y
11651176
end
11661177
return X, Y
11671178
end
1179+
function _sincos(X, Y)
1180+
T = eltype(X)
1181+
S = T(0.5)*im .* (Y .- X)
1182+
C = T(0.5) .* (X .+ Y)
1183+
S, C
1184+
end
11681185

11691186
"""
11701187
tan(A::AbstractMatrix)
@@ -1205,8 +1222,9 @@ function cosh(A::AbstractMatrix)
12051222
return copytri!(parent(cosh(Hermitian(A))), 'U', true)
12061223
end
12071224
X = exp(A)
1208-
@. X = (X + $exp!(float(-A))) / 2
1209-
return X
1225+
negA = @. float(-A)
1226+
Y = exp_maybe_inplace(negA)
1227+
_broadcast!!((x,y) -> (x + y)/2, X, X, Y)
12101228
end
12111229

12121230
"""
@@ -1221,8 +1239,9 @@ function sinh(A::AbstractMatrix)
12211239
return copytri!(parent(sinh(Hermitian(A))), 'U', true)
12221240
end
12231241
X = exp(A)
1224-
@. X = (X - $exp!(float(-A))) / 2
1225-
return X
1242+
negA = @. float(-A)
1243+
Y = exp_maybe_inplace(negA)
1244+
_broadcast!!((x,y) -> (x - y)/2, X, X, Y)
12261245
end
12271246

12281247
"""
@@ -1237,15 +1256,20 @@ function tanh(A::AbstractMatrix)
12371256
return copytri!(parent(tanh(Hermitian(A))), 'U', true)
12381257
end
12391258
X = exp(A)
1240-
Y = exp!(float.(.-A))
1259+
negA = @. float(-A)
1260+
Y = exp_maybe_inplace(negA)
1261+
X′, Y′ = _subadd!!(X, Y)
1262+
return X′ / Y′
1263+
end
1264+
function _subadd!!(X::StridedMatrix, Y::StridedMatrix)
12411265
@inbounds for i in eachindex(X, Y)
12421266
x, y = X[i], Y[i]
12431267
X[i] = x - y
12441268
Y[i] = x + y
12451269
end
1246-
X /= Y
1247-
return X
1270+
return X, Y
12481271
end
1272+
_subadd!!(X, Y) = X - Y, X + Y
12491273

12501274
"""
12511275
acos(A::AbstractMatrix)

test/dense.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ module TestDense
55
using Test, LinearAlgebra, Random
66
using LinearAlgebra: BlasComplex, BlasFloat, BlasReal
77

8+
const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")
9+
isdefined(Main, :FillArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "FillArrays.jl"))
10+
import Main.FillArrays
11+
812
@testset "Check that non-floats are correctly promoted" begin
913
@test [1 0 0; 0 1 0]\[1,1] [1;1;0]
1014
end
@@ -1302,4 +1306,17 @@ end
13021306
end
13031307
end
13041308

1309+
@testset "trig functions for non-strided" begin
1310+
@testset for T in (Float32,ComplexF32)
1311+
A = FillArrays.Fill(T(0.1), 4, 4) # all.(<(1), eigvals(A)) for atanh
1312+
M = Matrix(A)
1313+
@testset for f in (sin,cos,tan,sincos,sinh,cosh,tanh)
1314+
@test f(A) == f(M)
1315+
end
1316+
@testset for f in (asin,acos,atan,asinh,acosh,atanh)
1317+
@test f(A) == f(M)
1318+
end
1319+
end
1320+
end
1321+
13051322
end # module TestDense

0 commit comments

Comments
 (0)