Skip to content

Commit ca8f6cf

Browse files
authored
[CUSOLVER] Interface XsyevBatched (#2577)
1 parent 860eb88 commit ca8f6cf

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

lib/cusolver/dense_generic.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ end
190190

191191
# Xlarft!
192192
function larft!(direct::Char, storev::Char, v::StridedCuMatrix{T}, tau::StridedCuVector{T}, t::StridedCuMatrix{T}) where {T <: BlasFloat}
193+
CUSOLVER.version() < v"11.6.0" && throw(ErrorException("This operation is not supported by the current CUDA version."))
193194
n, k = size(v)
194195
ktau = length(tau)
195196
mt, nt = size(t)
@@ -449,6 +450,7 @@ end
449450

450451
# Xgeev
451452
function Xgeev!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
453+
CUSOLVER.version() < v"11.7.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
452454
n = checksquare(A)
453455
VL = if jobvl == 'V'
454456
CuMatrix{T}(undef, n, n)
@@ -492,6 +494,44 @@ function Xgeev!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}) where {T <: Bla
492494
return W, VL, VR
493495
end
494496

497+
# XsyevBatched
498+
"""
    XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}

Batched eigen-decomposition through `cusolverDnXsyevBatched`.

`A` holds `batch_size` square `n × n` matrices stored side by side, i.e. it has
size `(n, n * batch_size)`; matrix `i` occupies columns `(i-1)*n+1:i*n`.

# Arguments
- `jobz`: `'N'` to return only the eigenvalues, `'V'` to also return `A`, which
  the solver has then overwritten in place with the eigenvectors.
- `uplo`: `'L'` or `'U'` (validated with `chkuplo`), the triangle of each matrix
  handed to the solver.
- `A`: the batch of matrices; it is passed to the solver as an in/out buffer.

# Returns
- `W` when `jobz == 'N'`: a real `CuVector` of length `n * batch_size`, where the
  eigenvalues of matrix `i` are stored in `W[(i-1)*n+1:i*n]`.
- `(W, A)` when `jobz == 'V'`.

# Throws
- `ErrorException` if CUSOLVER is older than 11.7.1.
- `DimensionMismatch` if `A` has no rows or if its number of columns is not a
  multiple of its number of rows.
- Whatever `chkargsok` raises when the solver reports a bad status for a matrix.
"""
function XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
    CUSOLVER.version() < v"11.7.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
    chkuplo(uplo)
    n, num_matrices = size(A)
    # Validate the packed layout up front: without these checks an empty `A` fails
    # with a bare DivideError and a ragged `A` silently loses its trailing columns.
    n == 0 && throw(DimensionMismatch("A must have at least one row, got size $(size(A))"))
    num_matrices % n == 0 || throw(DimensionMismatch(
        "the number of columns of A ($num_matrices) must be a multiple of its number of rows ($n)"))
    batch_size = num_matrices ÷ n
    R = real(T)
    lda = max(1, stride(A, 2))
    W = CuVector{R}(undef, n * batch_size)
    params = CuSolverParameters()
    dh = dense_handle()
    # One status entry per matrix in the batch.
    resize!(dh.info, batch_size)

    # Query the device and host workspace sizes (in bytes) required by the solver.
    function bufferSize()
        out_cpu = Ref{Csize_t}(0)
        out_gpu = Ref{Csize_t}(0)
        cusolverDnXsyevBatched_bufferSize(dh, params, jobz, uplo, n,
                                          T, A, lda, R, W, T, out_gpu, out_cpu, batch_size)
        return out_gpu[], out_cpu[]
    end
    with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) do buffer_gpu, buffer_cpu
        cusolverDnXsyevBatched(dh, params, jobz, uplo, n, T, A,
                               lda, R, W, T, buffer_gpu, sizeof(buffer_gpu),
                               buffer_cpu, sizeof(buffer_cpu), dh.info, batch_size)
    end

    # Bring the per-matrix status codes to the host and check each of them.
    info = @allowscalar collect(dh.info)
    for i = 1:batch_size
        chkargsok(info[i] |> BlasInt)
    end

    if jobz == 'N'
        return W
    elseif jobz == 'V'
        return W, A
    end
end
534+
495535
# LAPACK
496536
for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
497537
@eval begin

test/libraries/cusolver/dense_generic.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,36 @@ p = 5
3131
end
3232
end
3333
end
34+
35+
# Checks CUSOLVER.XsyevBatched! on a batch of Hermitian positive-definite matrices
# packed side by side into an n × (n * batch_size) matrix.
@testset "syevBatched!" begin
    batch_size = 5
    for uplo in ('L', 'U')
        # NOTE(review): this combination is deliberately skipped; the reason is not
        # visible here (presumably a known solver issue) — confirm before removing.
        (uplo == 'L') && (elty == ComplexF32) && continue

        A = rand(elty, n, n * batch_size)
        B = rand(elty, n, n * batch_size)
        for i = 1:batch_size
            cols = (i-1)*n+1:i*n
            S = rand(elty,n,n)
            S = S * S' + I
            # B keeps the full matrix as the reference; A only gets the triangle
            # selected by `uplo`, which is what is handed to the solver.
            B[:,cols] .= S
            S = uplo == 'L' ? tril(S) : triu(S)
            A[:,cols] .= S
        end

        # Eigenvalues and eigenvectors: every block must satisfy B * V ≈ V * Λ.
        d_A = CuMatrix(A)
        d_W, d_V = CUSOLVER.XsyevBatched!('V', uplo, d_A)
        W = collect(d_W)
        V = collect(d_V)
        for i = 1:batch_size
            cols = (i-1)*n+1:i*n
            Bᵢ = B[:,cols]
            Wᵢ = Diagonal(W[cols])
            Vᵢ = V[:,cols]
            @test Bᵢ * Vᵢ ≈ Vᵢ * Wᵢ
        end

        # Eigenvalues only: must agree with the eigenvalues of the 'V' run above.
        d_A = CuMatrix(A)
        d_W = CUSOLVER.XsyevBatched!('N', uplo, d_A)
        @test collect(d_W) ≈ W
    end
end
3464
end
3565

3666
if CUSOLVER.version() >= v"11.6.0"

0 commit comments

Comments
 (0)