Skip to content

Commit e4eecb7

Browse files
nikopjJanjusevicJanjusevic
authored
Batched SVD support (gesvdjBatched and gesvdaStridedBatched) (#2063)
Co-authored-by: Janjusevic <npj226@gn-0002.cm.cluster> Co-authored-by: Janjusevic <npj226@gn-0003.cm.cluster>
1 parent 4dcc147 commit e4eecb7

File tree

3 files changed

+174
-7
lines changed

3 files changed

+174
-7
lines changed

lib/cusolver/dense.jl

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,120 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdj_bufferSize, :cusolverDnS
494494
end
495495
end
496496

497+
# Batched Jacobi SVD. One `gesvdj!` method is generated per element type; each wraps
# the matching cusolverDn?gesvdjBatched routine and its *_bufferSize query.
for (bname, fname, elty, relty) in
        ((:cusolverDnSgesvdjBatched_bufferSize, :cusolverDnSgesvdjBatched, :Float32, :Float32),
         (:cusolverDnDgesvdjBatched_bufferSize, :cusolverDnDgesvdjBatched, :Float64, :Float64),
         (:cusolverDnCgesvdjBatched_bufferSize, :cusolverDnCgesvdjBatched, :ComplexF32, :Float32),
         (:cusolverDnZgesvdjBatched_bufferSize, :cusolverDnZgesvdjBatched, :ComplexF64, :Float64))
    @eval begin
        # gesvdj!(jobz, A; tol, max_sweeps) -> (U, S, V)
        #
        # `A` is an m x n x batchSize array handed directly to the solver.
        # NOTE(review): presumably the solver overwrites `A` (hence the `!`) - confirm
        # against the cuSOLVER documentation.
        #
        # Keywords:
        #   tol        - forwarded to cusolverDnXgesvdjSetTolerance
        #   max_sweeps - forwarded to cusolverDnXgesvdjSetMaxSweeps
        #
        # Returns freshly allocated `U` (m x m x batchSize), `S` (min(m, n) x batchSize)
        # and `V` (n x n x batchSize).
        #
        # Throws `ArgumentError` when m > 32 or n > 32; every per-batch info code is
        # passed through `chkargsok`.
        function gesvdj!(jobz::Char,
                         A::StridedCuArray{$elty,3};
                         tol::$relty=eps($relty),
                         max_sweeps::Int=100)
            m, n, batchSize = size(A)
            if m > 32 || n > 32
                throw(ArgumentError("CUSOLVER's gesvdjBatched currently requires m <=32 and n <= 32"))
            end
            lda = max(1, stride(A, 2))

            U = CuArray{$elty}(undef, m, m, batchSize)
            ldu = max(1, stride(U, 2))

            S = CuArray{$relty}(undef, min(m, n), batchSize)

            V = CuArray{$elty}(undef, n, n, batchSize)
            ldv = max(1, stride(V, 2))

            params = Ref{gesvdjInfo_t}(C_NULL)
            cusolverDnCreateGesvdjInfo(params)

            # Workspace size query used by `with_workspace`.
            function bufferSize()
                out = Ref{Cint}(0)
                $bname(dense_handle(), jobz, m, n, A, lda, S, U, ldu, V, ldv,
                       out, params[], batchSize)
                return out[]
            end

            # Everything after the handle is created runs under try/finally so the
            # handle is destroyed even if the solver call or the status check throws.
            try
                cusolverDnXgesvdjSetTolerance(params[], tol)
                cusolverDnXgesvdjSetMaxSweeps(params[], max_sweeps)

                devinfo = CuArray{Cint}(undef, batchSize)
                with_workspace($elty, bufferSize) do work
                    $fname(dense_handle(), jobz, m, n, A, lda, S, U, ldu, V, ldv,
                           work, length(work), devinfo, params[], batchSize)
                end

                info = @allowscalar collect(devinfo)
                unsafe_free!(devinfo)

                # Double check the solver's exit status
                for i = 1:batchSize
                    chkargsok(BlasInt(info[i]))
                end
            finally
                cusolverDnDestroyGesvdjInfo(params[])
            end

            return U, S, V
        end
    end
end
552+
553+
# Approximate strided-batched SVD. One `gesvda!` method is generated per element type;
# each wraps the matching cusolverDn?gesvdaStridedBatched routine and its
# *_bufferSize query.
for (bname, fname, elty, relty) in
        ((:cusolverDnSgesvdaStridedBatched_bufferSize, :cusolverDnSgesvdaStridedBatched, :Float32, :Float32),
         (:cusolverDnDgesvdaStridedBatched_bufferSize, :cusolverDnDgesvdaStridedBatched, :Float64, :Float64),
         (:cusolverDnCgesvdaStridedBatched_bufferSize, :cusolverDnCgesvdaStridedBatched, :ComplexF32, :Float32),
         (:cusolverDnZgesvdaStridedBatched_bufferSize, :cusolverDnZgesvdaStridedBatched, :ComplexF64, :Float64))
    @eval begin
        # gesvda!(jobz, A; rank) -> (U, S, V)
        #
        # `A` is an m x n x batchSize array handed directly to the solver.
        # `rank` is the number of singular values/vectors requested per matrix and
        # defaults to min(m, n).
        #
        # Returns freshly allocated `U` (m x rank x batchSize), `S` (rank x batchSize)
        # and `V` (n x rank x batchSize).
        #
        # Throws `ArgumentError` when m < n or when `rank` lies outside 0:min(m, n);
        # every per-batch info code is passed through `chkargsok`.
        function gesvda!(jobz::Char,
                         A::StridedCuArray{$elty,3};
                         rank::Int=min(size(A,1), size(A,2)))
            m, n, batchSize = size(A)
            # nikopj: I can't find the documentation for this...
            if m < n
                throw(ArgumentError("CUSOLVER's gesvda currently requires m >= n"))
            end
            # A matrix has at most min(m, n) singular values; reject anything else
            # before it is used to size the outputs and strides given to the solver.
            if rank < 0 || rank > min(m, n)
                throw(ArgumentError(string("CUSOLVER's gesvda requires 0 <= rank <= min(m, n), got rank = ", rank)))
            end
            lda = max(1, stride(A, 2))
            strideA = stride(A, 3)

            U = CuArray{$elty}(undef, m, rank, batchSize)
            ldu = max(1, stride(U, 2))
            strideU = stride(U, 3)

            S = CuArray{$relty}(undef, rank, batchSize)
            strideS = stride(S, 2)

            V = CuArray{$elty}(undef, n, rank, batchSize)
            ldv = max(1, stride(V, 2))
            strideV = stride(V, 3)

            # Workspace size query used by `with_workspace`.
            function bufferSize()
                out = Ref{Cint}(0)
                $bname(dense_handle(), jobz, rank, m, n, A, lda, strideA,
                       S, strideS, U, ldu, strideU, V, ldv, strideV,
                       out, batchSize)
                return out[]
            end

            devinfo = CuArray{Cint}(undef, batchSize)
            # Host-side residual storage handed to the solver; it is not inspected
            # or returned here.
            h_RnrmF = Array{Cdouble}(undef, batchSize)

            with_workspace($elty, bufferSize) do work
                $fname(dense_handle(), jobz, rank, m, n, A, lda, strideA,
                       S, strideS, U, ldu, strideU, V, ldv, strideV,
                       work, length(work), devinfo, h_RnrmF, batchSize)
            end

            info = @allowscalar collect(devinfo)
            unsafe_free!(devinfo)

            # Double check the solver's exit status
            for i = 1:batchSize
                chkargsok(BlasInt(info[i]))
            end

            return U, S, V
        end
    end
end
610+
497611
for (jname, bname, fname, elty, relty) in ((:syevd!, :cusolverDnSsyevd_bufferSize, :cusolverDnSsyevd, :Float32, :Float32),
498612
(:syevd!, :cusolverDnDsyevd_bufferSize, :cusolverDnDsyevd, :Float64, :Float64),
499613
(:heevd!, :cusolverDnCheevd_bufferSize, :cusolverDnCheevd, :ComplexF32, :Float32),

lib/cusolver/linalg.jl

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -231,14 +231,17 @@ LinearAlgebra.rmul!(A::CuVecOrMat{T},
231231
# Algorithm selectors for the `alg` keyword of `svd`, `svd!`, `svdvals` and `svdvals!`.
# The `_svd!` / `_svdvals!` methods dispatch on these singleton types:
#   QRAlgorithm          -> `gesvd!`
#   JacobiAlgorithm      -> `gesvdj!` (the default `alg`)
#   ApproximateAlgorithm -> `gesvda!` (methods exist for 3-D inputs only)
abstract type SVDAlgorithm end
232232
struct QRAlgorithm <: SVDAlgorithm end
233233
struct JacobiAlgorithm <: SVDAlgorithm end
234+
struct ApproximateAlgorithm <: SVDAlgorithm end
234235

235-
LinearAlgebra.svd!(A::CuMatrix{T}; full::Bool=false,
236+
# Either a single GPU matrix or a 3-D GPU array; used to let the public SVD entry
# points accept both and leave the choice of solver to `_svd!` / `_svdvals!` dispatch.
const CuMatOrBatched{T} = Union{CuMatrix{T}, CuArray{T,3}} where T
237+
238+
# In-place SVD entry point for GPU matrices and 3-D GPU arrays. The work is delegated
# to the `_svd!` method selected by the array rank and by `alg`.
function LinearAlgebra.svd!(A::CuMatOrBatched{T};
                            full::Bool=false,
                            alg::SVDAlgorithm=JacobiAlgorithm()) where {T}
    return _svd!(A, full, alg)
end
238-
LinearAlgebra.svd(A::CuMatrix; full=false, alg::SVDAlgorithm=JacobiAlgorithm()) =
241+
# Non-mutating SVD entry point: factorizes the array produced by `copy_cublasfloat(A)`
# instead of `A` itself.
function LinearAlgebra.svd(A::CuMatOrBatched; full=false, alg::SVDAlgorithm=JacobiAlgorithm())
    scratch = copy_cublasfloat(A)
    return _svd!(scratch, full, alg)
end
240243

241-
_svd!(A::CuMatrix{T}, full::Bool, alg::SVDAlgorithm) where T =
244+
# Fallback reached when no more specific `_svd!` method exists for the given
# combination of array rank and algorithm.
function _svd!(A::CuMatOrBatched, full::Bool, alg::SVDAlgorithm)
    throw(ArgumentError("Unsupported value for `alg` keyword."))
end
243246
function _svd!(A::CuMatrix{T}, full::Bool, alg::QRAlgorithm) where T
244247
U, S, Vt = gesvd!(full ? 'A' : 'S', full ? 'A' : 'S', A)
@@ -248,16 +251,39 @@ function _svd!(A::CuMatrix{T}, full::Bool, alg::JacobiAlgorithm) where T
248251
U, S, V = gesvdj!('V', Int(!full), A)
249252
return SVD(U, S, V')
250253
end
254+
# Batched Jacobi SVD of a 3-D array: runs `gesvdj!` with vectors requested and wraps
# the three returned arrays in a `CuSVDBatched`. `full` is accepted but not consulted.
function _svd!(A::CuArray{T,3}, full::Bool, alg::JacobiAlgorithm) where T
    return CuSVDBatched(gesvdj!('V', A)...)
end
258+
259+
# Batched approximate SVD of a 3-D array: runs `gesvda!` with vectors requested and
# wraps the three returned arrays in a `CuSVDBatched`. `full` is accepted but not
# consulted; `rank` is forwarded to `gesvda!`.
function _svd!(A::CuArray{T,3}, full::Bool, alg::ApproximateAlgorithm;
               rank::Int=min(size(A,1), size(A,2))) where T
    return CuSVDBatched(gesvda!('V', A; rank=rank)...)
end
263+
264+
# Result container for the batched SVD methods: holds the `U`, `S` and `V` arrays
# exactly as returned by `gesvdj!` / `gesvda!` on a 3-D input. `V` is stored as
# returned by the solver (it is not adjointed here).
# NOTE(review): `U` and `S` are annotated with `CuArray{T,3}` / `CuMatrix{Tr}` rather
# than with type parameters - confirm whether these are concrete for the CuArray
# definition in use, since abstract field types would hinder inference.
struct CuSVDBatched{T,Tr,A<:AbstractArray{T,3}} <: LinearAlgebra.Factorization{T}
265+
U::CuArray{T,3}
266+
S::CuMatrix{Tr}
267+
V::A
268+
end
269+
270+
# Iteration protocol so that `U, S, V = F` destructures a `CuSVDBatched`.
# The `Val` state names the component to yield next; `Val(:done)` ends iteration.
Base.iterate(F::CuSVDBatched) = (F.U, Val(:S))
Base.iterate(F::CuSVDBatched, ::Val{:S}) = (F.S, Val(:V))
Base.iterate(F::CuSVDBatched, ::Val{:V}) = (F.V, Val(:done))
Base.iterate(F::CuSVDBatched, ::Val{:done}) = nothing
251275

252-
LinearAlgebra.svdvals!(A::CuMatrix{T}; alg::SVDAlgorithm=JacobiAlgorithm()) where {T} =
276+
# In-place singular-value entry point for GPU matrices and 3-D GPU arrays. The work
# is delegated to the `_svdvals!` method selected by the array rank and by `alg`.
function LinearAlgebra.svdvals!(A::CuMatOrBatched{T};
                                alg::SVDAlgorithm=JacobiAlgorithm()) where {T}
    return _svdvals!(A, alg)
end
254-
LinearAlgebra.svdvals(A::CuMatrix; alg::SVDAlgorithm=JacobiAlgorithm()) =
278+
# Non-mutating singular-value entry point: operates on the array produced by
# `copy_cublasfloat(A)` instead of `A` itself.
function LinearAlgebra.svdvals(A::CuMatOrBatched; alg::SVDAlgorithm=JacobiAlgorithm())
    scratch = copy_cublasfloat(A)
    return _svdvals!(scratch, alg)
end
256280

257-
_svdvals!(A::CuMatrix{T}, alg::SVDAlgorithm) where T =
281+
# Fallback reached when no more specific `_svdvals!` method exists for the given
# combination of array rank and algorithm.
function _svdvals!(A::CuMatOrBatched{T}, alg::SVDAlgorithm) where T
    throw(ArgumentError("Unsupported value for `alg` keyword."))
end
259283
# Singular values of a matrix via `gesvd!` with both job flags set to 'N'; the second
# element of its result is returned.
function _svdvals!(A::CuMatrix{T}, alg::QRAlgorithm) where T
    result = gesvd!('N', 'N', A)
    return result[2]
end
260-
_svdvals!(A::CuMatrix{T}, alg::JacobiAlgorithm) where T = gesvdj!('N', 1, A::CuMatrix{T})[2]
284+
# Singular values of a matrix via the non-batched `gesvdj!` with job flag 'N'; the
# second element of its result is returned. The argument is already constrained to
# `CuMatrix{T}` by the signature, so no call-site type assertion is needed.
_svdvals!(A::CuMatrix{T}, alg::JacobiAlgorithm) where T = gesvdj!('N', 1, A)[2]
285+
# Singular values of every matrix in a 3-D array via the batched `gesvdj!` with job
# flag 'N'; the second element of its result is returned.
function _svdvals!(A::CuArray{T,3}, alg::JacobiAlgorithm) where T
    result = gesvdj!('N', A)
    return result[2]
end
286+
# Singular values of every matrix in a 3-D array via `gesvda!` with job flag 'N';
# `rank` is forwarded and the second element of the result is returned.
function _svdvals!(A::CuArray{T,3}, alg::ApproximateAlgorithm;
                   rank=min(size(A,1), size(A,2))) where T
    result = gesvda!('N', A; rank=rank)
    return result[2]
end
261287

262288
### opnorm2, enabled by svdvals
263289

test/libraries/cusolver/dense.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,33 @@ k = 1
328328
_svd(A) = svd(A; alg=CUSOLVER.QRAlgorithm())
329329
@inferred _svd(CUDA.rand(Float32, 4, 4))
330330

331+
# Batched SVD: compare each GPU factorization against the CPU `svd` of the same slice.
# Shapes a solver rejects (m < n for the approximate algorithm, any dimension > 32 for
# the Jacobi algorithm) must raise an ArgumentError instead.
@testset "batched $svd_f with $alg algorithm" for
    svd_f in (svd, svd!),
    alg in (CUSOLVER.JacobiAlgorithm(), CUSOLVER.ApproximateAlgorithm()),
    (_m, _n, _b) in ((m, n, n), (n, m, n), (33,33,1))

    A = rand(elty, _m, _n, _b)
    d_A = CuArray(A)
    r = min(_m, _n)

    if (_m >= _n && alg == CUSOLVER.ApproximateAlgorithm()) || (_m <= 32 && _n <= 32 && alg == CUSOLVER.JacobiAlgorithm())
        # svd! mutates its argument, so always factorize a copy of d_A.
        d_U, d_S, d_V = svd_f(copy(d_A); full=true, alg=alg)
        h_S = collect(d_S)
        h_U = collect(d_U)
        h_V = collect(d_V)
        for i=1:_b
            U, S, V = svd(A[:,:,i]; full=true)
            @test abs.(h_U[:,:,i]'*h_U[:,:,i]) ≈ I
            @test abs.(h_U[:,1:min(_m,_n),i]'U[:,1:min(_m,_n)]) ≈ I
            @test collect(svdvals(d_A; alg=alg))[:,i] ≈ svdvals(A[:,:,i])
            @test abs.(h_V[:,:,i]'*h_V[:,:,i]) ≈ I
            @test collect(d_U[:,:,i]'*d_A[:,:,i]*d_V[:,:,i])[1:r,1:r] ≈ (U'*A[:,:,i]*V)[1:r,1:r]
        end
    else
        @test_throws ArgumentError svd(d_A; alg=alg)
    end
end
357+
331358
@testset "2-opnorm($sz x $elty)" for sz in [(2, 0), (2, 3)]
332359
A = rand(elty, sz)
333360
d_A = CuArray(A)

0 commit comments

Comments
 (0)