Skip to content

Commit f67bfbd

Browse files
amontoisonmaleadt
authored andcommitted
[CUSOLVER] Add more tests for the dense SVD
1 parent 6bdb146 commit f67bfbd

File tree

3 files changed

+29
-7
lines changed

3 files changed

+29
-7
lines changed

lib/cusolver/dense.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -389,31 +389,30 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg
389389
A::StridedCuMatrix{$elty})
390390
m, n = size(A)
391391
(m < n) && throw(ArgumentError("CUSOLVER's gesvd requires m ≥ n"))
392+
k = min(m, n)
392393
lda = max(1, stride(A, 2))
393394

394395
U = if jobu === 'A'
395396
similar(A, $elty, (m, m))
396397
elseif jobu == 'S' || jobu === 'O'
397-
similar(A, $elty, (m, min(m, n)))
398+
similar(A, $elty, (m, k))
398399
elseif jobu === 'N'
399400
CU_NULL
400401
else
401402
error("jobu must be one of 'A', 'S', 'O', or 'N'")
402403
end
403-
ldu = U == CU_NULL ? 1 : max(1, stride(U, 2))
404-
405-
S = similar(A, $relty, min(m, n))
406-
404+
ldu = U == CU_NULL ? 1 : max(1, stride(U, 2))
405+
S = similar(A, $relty, k)
407406
Vt = if jobvt === 'A'
408407
similar(A, $elty, (n, n))
409408
elseif jobvt === 'S' || jobvt === 'O'
410-
similar(A, $elty, (min(m, n), n))
409+
similar(A, $elty, k, n))
411410
elseif jobvt === 'N'
412411
CU_NULL
413412
else
414413
error("jobvt must be one of 'A', 'S', 'O', or 'N'")
415414
end
416-
ldvt = Vt == CU_NULL ? 1 : max(1, stride(Vt, 2))
415+
ldvt = Vt == CU_NULL ? 1 : max(1, stride(Vt, 2))
417416
dh = dense_handle()
418417

419418
function bufferSize()

test/libraries/cusolver/dense.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,21 @@ k = 1
194194
@test p h_TAUP
195195
end
196196

197+
@testset "gesvd!" begin
198+
A = rand(elty,m,n)
199+
d_A = CuMatrix(A)
200+
U, Σ, Vt = CUSOLVER.gesvd!('A', 'A', d_A)
201+
@test A collect(U[:,1:n] * Diagonal(Σ) * Vt)
202+
203+
for jobu in ('A', 'S', 'N', 'O')
204+
for jobvt in ('A', 'S', 'N', 'O')
205+
d_A = CuMatrix(A)
206+
U2, Σ2, Vt2 = CUSOLVER.gesvd!(jobu, jobvt, d_A)
207+
@test Σ Σ2
208+
end
209+
end
210+
end
211+
197212
@testset "syevd!" begin
198213
A = rand(elty,m,m)
199214
A += A'

test/libraries/cusolver/dense_generic.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,14 @@ p = 5
149149
d_A = CuMatrix(A)
150150
U, Σ, Vt = CUSOLVER.Xgesvd!('A', 'A', d_A)
151151
@test A collect(U[:,1:n] * Diagonal(Σ) * Vt)
152+
153+
for jobu in ('A', 'S', 'N', 'O')
154+
for jobvt in ('A', 'S', 'N', 'O')
155+
d_A = CuMatrix(A)
156+
U2, Σ2, Vt2 = CUSOLVER.Xgesvd!(jobu, jobvt, d_A)
157+
@test Σ Σ2
158+
end
159+
end
152160
end
153161

154162
@testset "gesvdp!" begin

0 commit comments

Comments
 (0)