Skip to content

Commit cd6e5c0

Browse files
authored
Add conjugate option to copytri! (#413)
1 parent dd5c428 commit cd6e5c0

File tree

2 files changed

+30
-5
lines changed

2 files changed

+30
-5
lines changed

src/host/linalg.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,18 @@ end
3232

3333
## copy upper triangle to lower and vice versa
3434

35-
function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar) where T
35+
function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, conjugate::Bool=false) where T
3636
n = LinearAlgebra.checksquare(A)
37-
if uplo == 'U'
37+
if uplo == 'U' && conjugate
38+
gpu_call(A) do ctx, _A
39+
I = @cartesianidx _A
40+
i, j = Tuple(I)
41+
if j > i
42+
_A[j,i] = conj(_A[i,j])
43+
end
44+
return
45+
end
46+
elseif uplo == 'U' && !conjugate
3847
gpu_call(A) do ctx, _A
3948
I = @cartesianidx _A
4049
i, j = Tuple(I)
@@ -43,7 +52,16 @@ function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar) whe
4352
end
4453
return
4554
end
46-
elseif uplo == 'L'
55+
elseif uplo == 'L' && conjugate
56+
gpu_call(A) do ctx, _A
57+
I = @cartesianidx _A
58+
i, j = Tuple(I)
59+
if j > i
60+
_A[i,j] = conj(_A[j,i])
61+
end
62+
return
63+
end
64+
elseif uplo == 'L' && !conjugate
4765
gpu_call(A) do ctx, _A
4866
I = @cartesianidx _A
4967
i, j = Tuple(I)

test/testsuite/linalg.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,15 @@
6363

6464
@testset "triangular" begin
6565
@testset "copytri!" begin
66-
@testset for uplo in ('U', 'L')
67-
@test compare(x -> LinearAlgebra.copytri!(x, uplo), AT, rand(Float32, 128, 128))
66+
@testset for eltya in (Float32, Float64, ComplexF32, ComplexF64), uplo in ('U', 'L'), conjugate in (true, false)
67+
n = 128
68+
areal = randn(n,n)/2
69+
aimg = randn(n,n)/2
70+
if !(eltya in eltypes)
71+
continue
72+
end
73+
a = convert(Matrix{eltya}, eltya <: Complex ? complex.(areal, aimg) : areal)
74+
@test compare(x -> LinearAlgebra.copytri!(x, uplo, conjugate), AT, a)
6875
end
6976
end
7077

0 commit comments

Comments
 (0)