|
190 | 190 |
|
191 | 191 | # Xlarft!
|
192 | 192 | function larft!(direct::Char, storev::Char, v::StridedCuMatrix{T}, tau::StridedCuVector{T}, t::StridedCuMatrix{T}) where {T <: BlasFloat}
|
| 193 | + CUSOLVER.version() < v"11.6.0" && throw(ErrorException("This operation is not supported by the current CUDA version.")) |
193 | 194 | n, k = size(v)
|
194 | 195 | ktau = length(tau)
|
195 | 196 | mt, nt = size(t)
|
|
449 | 450 |
|
450 | 451 | # Xgeev
|
451 | 452 | function Xgeev!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
|
| 453 | + CUSOLVER.version() < v"11.7.1" && throw(ErrorException("This operation is not supported by the current CUDA version.")) |
452 | 454 | n = checksquare(A)
|
453 | 455 | VL = if jobvl == 'V'
|
454 | 456 | CuMatrix{T}(undef, n, n)
|
@@ -492,6 +494,44 @@ function Xgeev!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}) where {T <: Bla
|
492 | 494 | return W, VL, VR
|
493 | 495 | end
|
494 | 496 |
|
| 497 | +# XsyevBatched |
"""
    XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}

Batched eigen-decomposition through cuSOLVER's generic `cusolverDnXsyevBatched`
routine (symmetric/Hermitian matrices, per the cuSOLVER documentation).

`A` is interpreted as `batch_size` square matrices of order `n = size(A, 1)` laid out
side by side, i.e. `size(A) == (n, n * batch_size)`; the leading dimension handed to
cuSOLVER is `stride(A, 2)`.

# Arguments
- `jobz`: `'N'` to return only the eigenvalue vector, `'V'` to also return `A`
  (which, per the cuSOLVER documentation, then holds the eigenvectors).
- `uplo`: triangle selector, validated with `chkuplo`.
- `A`: device matrix holding the batch; passed to cuSOLVER as an in/out buffer.

# Returns
- `W` when `jobz == 'N'`, a `CuVector{real(T)}` of length `n * batch_size`.
- `(W, A)` when `jobz == 'V'`.

# Throws
- `ErrorException` if the loaded cuSOLVER is older than 11.7.1.
- `ArgumentError` if `jobz` is neither `'N'` nor `'V'`.
- `DimensionMismatch` if `size(A, 1) == 0` or `size(A, 2)` is not a multiple of
  `size(A, 1)`.
- Whatever `chkargsok` raises for a per-matrix `info` entry it rejects.
"""
function XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
    CUSOLVER.version() < v"11.7.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
    chkuplo(uplo)
    # Reject unsupported modes up front so the function never falls off the end
    # and returns `nothing` after having launched GPU work.
    jobz in ('N', 'V') ||
        throw(ArgumentError("jobz must be 'N' or 'V', got '$jobz'"))

    n, num_matrices = size(A)
    # The integer division below would otherwise silently ignore trailing columns
    # (or raise an opaque DivideError when n == 0).
    (n > 0 && num_matrices % n == 0) ||
        throw(DimensionMismatch("A has size ($n, $num_matrices); the number of columns must be a multiple of the (nonzero) number of rows"))
    batch_size = num_matrices ÷ n

    R = real(T)
    lda = max(1, stride(A, 2))
    W = CuVector{R}(undef, n * batch_size)
    params = CuSolverParameters()
    dh = dense_handle()
    # One status entry per matrix in the batch.
    resize!(dh.info, batch_size)

    # Query the device and host workspace sizes (in bytes) required by the solver.
    function bufferSize()
        out_cpu = Ref{Csize_t}(0)
        out_gpu = Ref{Csize_t}(0)
        cusolverDnXsyevBatched_bufferSize(dh, params, jobz, uplo, n,
                                          T, A, lda, R, W, T, out_gpu, out_cpu, batch_size)
        return out_gpu[], out_cpu[]
    end
    with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) do buffer_gpu, buffer_cpu
        cusolverDnXsyevBatched(dh, params, jobz, uplo, n, T, A,
                               lda, R, W, T, buffer_gpu, sizeof(buffer_gpu),
                               buffer_cpu, sizeof(buffer_cpu), dh.info, batch_size)
    end

    # Copy the per-matrix status codes to the host and check each one.
    info = @allowscalar collect(dh.info)
    for i in 1:batch_size
        chkargsok(BlasInt(info[i]))
    end

    return jobz == 'N' ? W : (W, A)
end
| 534 | + |
495 | 535 | # LAPACK
|
496 | 536 | for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
|
497 | 537 | @eval begin
|
|
0 commit comments