Skip to content

Commit 9daa9de

Browse files
committed
Do eigenpairs sorting in tests only
1 parent e730971 commit 9daa9de

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

lib/cusolver/linalg.jl

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -110,20 +110,6 @@ function Base.:\(F::Union{LinearAlgebra.LAPACKFactorizations{<:Any,<:CuArray},
110110
return LinearAlgebra._cut_B(BB, 1:n)
111111
end
112112

113-
# Adapted from LinearAlgebra.sorteig!().
114-
# Warning: not very efficient, but works.
115-
eigsortby::Real) = λ
116-
eigsortby::Complex) = (real(λ),imag(λ))
117-
function sorteig!::AbstractVector, X::AbstractMatrix, sortby::Union{Function,Nothing}=eigsortby)
118-
if sortby !== nothing # && !issorted(λ, by=sortby)
119-
p = sortperm(λ; by=sortby)
120-
λ .= λ[p] # permute!(λ, p)
121-
X .= X[:, p] # Base.permutecols!!(X, p)
122-
end
123-
return λ, X
124-
end
125-
sorteig!::AbstractVector, sortby::Union{Function,Nothing}=eigsortby) = sortby === nothing ? λ : sort!(λ, by=sortby)
126-
127113
# eigen
128114

129115
function LinearAlgebra.eigen(A::Symmetric{T,<:CuMatrix}) where {T<:BlasReal}
@@ -142,7 +128,7 @@ function LinearAlgebra.eigen(A::CuMatrix{T}) where {T<:BlasReal}
142128
A2 = copy(A)
143129
W, _, VR = Xgeev!('N', 'V', A2)
144130
C = Complex{T}
145-
U = CuMatrix{C}([1. 1.; im -im])
131+
U = CuMatrix{C}([1.0 1.0; im -im])
146132
VR = CuMatrix{C}(VR)
147133
h_W = collect(W)
148134
n = length(W)
@@ -151,16 +137,16 @@ function LinearAlgebra.eigen(A::CuMatrix{T}) where {T<:BlasReal}
151137
if imag(h_W[j]) == 0
152138
j += 1
153139
else
154-
VR[:,j:j+1] .= VR[:,j:j+1] * U
140+
VR[:, j:(j + 1)] .= VR[:, j:(j + 1)] * U
155141
j += 2
156142
end
157143
end
158-
return Eigen(sorteig!(W, VR)...)
144+
return Eigen(W, VR)
159145
end
160146
function LinearAlgebra.eigen(A::CuMatrix{T}) where {T<:BlasComplex}
161147
A2 = copy(A)
162148
r = Xgeev!('N', 'V', A2)
163-
return Eigen(sorteig!(r[1], r[3])...)
149+
return Eigen(r[1], r[3])
164150
end
165151

166152
# eigvals
@@ -179,11 +165,11 @@ end
179165

180166
function LinearAlgebra.eigvals(A::CuMatrix{T}) where {T <: BlasReal}
181167
A2 = copy(A)
182-
return sorteig!(Xgeev!('N', 'N', A2)[1])
168+
return Xgeev!('N', 'N', A2)[1]
183169
end
184170
function LinearAlgebra.eigvals(A::CuMatrix{T}) where {T <: BlasComplex}
185171
A2 = copy(A)
186-
return sorteig!(Xgeev!('N', 'N', A2)[1])
172+
return Xgeev!('N', 'N', A2)[1]
187173
end
188174

189175
# eigvecs

test/libraries/cusolver/dense.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,20 @@ p = 5
99
l = 13
1010
k = 1
1111

12+
# Adapted from LinearAlgebra.sorteig!().
13+
# Warning: not very efficient, but works.
14+
eigsortby::Real) = λ
15+
eigsortby::Complex) = (real(λ), imag(λ))
16+
function sorteig!::AbstractVector, X::AbstractMatrix, sortby::Union{Function, Nothing} = eigsortby)
17+
if sortby !== nothing # && !issorted(λ, by=sortby)
18+
p = sortperm(λ; by = sortby)
19+
λ .= λ[p] # permute!(λ, p)
20+
X .= X[:, p] # Base.permutecols!!(X, p)
21+
end
22+
return λ, X
23+
end
24+
sorteig!::AbstractVector, sortby::Union{Function, Nothing} = eigsortby) = sortby === nothing ? λ : sort!(λ, by = sortby)
25+
1226
@testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
1327
@testset "gesv!" begin
1428
@testset "irs_precision = AUTO" begin
@@ -342,6 +356,7 @@ k = 1
342356
d_A = CuArray(A)
343357
Eig = eigen(A)
344358
d_eig = eigen(d_A)
359+
sorteig!(d_eig.values, d_eig.vectors)
345360
@test Eig.values collect(d_eig.values)
346361
h_V = collect(d_eig.vectors)
347362
h_V⁻¹ = inv(h_V)
@@ -351,12 +366,15 @@ k = 1
351366
d_A = CuArray(A)
352367
W = eigvals(A)
353368
d_W = eigvals(d_A)
369+
sorteig!(d_W)
354370
@test W collect(d_W)
355371

356372
A = rand(elty,m,m)
357373
d_A = CuArray(A)
358374
V = eigvecs(A)
375+
d_W = eigvals(d_A)
359376
d_V = eigvecs(d_A)
377+
sorteig!(d_W, d_V)
360378
V⁻¹ = inv(V)
361379
@test abs.(V⁻¹*collect(d_V)) I
362380
end
@@ -402,6 +420,7 @@ k = 1
402420
d_A = CuArray(A)
403421
Eig = eigen(LinearAlgebra.Hermitian(A))
404422
d_eig = eigen(d_A)
423+
sorteig!(d_eig.values, d_eig.vectors)
405424
@test Eig.values collect(d_eig.values)
406425
d_eig = eigen(LinearAlgebra.Hermitian(d_A))
407426
@test Eig.values collect(d_eig.values)
@@ -420,6 +439,7 @@ k = 1
420439
d_A = CuArray(A)
421440
W = eigvals(LinearAlgebra.Hermitian(A))
422441
d_W = eigvals(d_A)
442+
sorteig!(d_W)
423443
@test W collect(d_W)
424444
d_W = eigvals(LinearAlgebra.Hermitian(d_A))
425445
@test W collect(d_W)
@@ -433,7 +453,9 @@ k = 1
433453
A += A'
434454
d_A = CuArray(A)
435455
V = eigvecs(LinearAlgebra.Hermitian(A))
456+
d_W = eigvals(d_A)
436457
d_V = eigvecs(d_A)
458+
sorteig!(d_W, d_V)
437459
h_V = collect(d_V)
438460
@test abs.(V'*h_V) I
439461
d_V = eigvecs(LinearAlgebra.Hermitian(d_A))

0 commit comments

Comments
 (0)