Skip to content

Commit 478a952

Browse files
authored
Merge pull request #2574 from amontoison/test_gesvd
2 parents 2dae25b + 48ca037 commit 478a952

File tree

4 files changed

+45
-18
lines changed

4 files changed

+45
-18
lines changed

lib/cusolver/dense.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -389,31 +389,30 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg
389389
A::StridedCuMatrix{$elty})
390390
m, n = size(A)
391391
(m < n) && throw(ArgumentError("CUSOLVER's gesvd requires m ≥ n"))
392+
k = min(m, n)
392393
lda = max(1, stride(A, 2))
393394

394395
U = if jobu === 'A'
395396
similar(A, $elty, (m, m))
396-
elseif jobu == 'S' || jobu === 'O'
397-
similar(A, $elty, (m, min(m, n)))
398-
elseif jobu === 'N'
397+
elseif jobu === 'S'
398+
similar(A, $elty, (m, k))
399+
elseif jobu === 'N' || jobu === 'O'
399400
CU_NULL
400401
else
401402
error("jobu must be one of 'A', 'S', 'O', or 'N'")
402403
end
403-
ldu = U == CU_NULL ? 1 : max(1, stride(U, 2))
404-
405-
S = similar(A, $relty, min(m, n))
406-
404+
ldu = U == CU_NULL ? 1 : max(1, stride(U, 2))
405+
S = similar(A, $relty, k)
407406
Vt = if jobvt === 'A'
408407
similar(A, $elty, (n, n))
409-
elseif jobvt === 'S' || jobvt === 'O'
410-
similar(A, $elty, (min(m, n), n))
411-
elseif jobvt === 'N'
408+
elseif jobvt === 'S'
409+
similar(A, $elty, (k, n))
410+
elseif jobvt === 'N' || jobvt === 'O'
412411
CU_NULL
413412
else
414413
error("jobvt must be one of 'A', 'S', 'O', or 'N'")
415414
end
416-
ldvt = Vt == CU_NULL ? 1 : max(1, stride(Vt, 2))
415+
ldvt = Vt == CU_NULL ? 1 : max(1, stride(Vt, 2))
417416
dh = dense_handle()
418417

419418
function bufferSize()

lib/cusolver/dense_generic.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -222,21 +222,22 @@ function Xgesvd!(jobu::Char, jobvt::Char, A::StridedCuMatrix{T}) where {T <: Bla
222222
m, n = size(A)
223223
R = real(T)
224224
(m < n) && throw(ArgumentError("The number of rows of A ($m) must be greater or equal to the number of columns of A ($n)"))
225+
k = min(m, n)
225226
U = if jobu == 'A'
226227
CuMatrix{T}(undef, m, m)
227-
elseif jobu == 'S' || jobu == 'O'
228-
CuMatrix{T}(undef, m, min(m, n))
229-
elseif jobu == 'N'
228+
elseif jobu == 'S'
229+
CuMatrix{T}(undef, m, k)
230+
elseif jobu == 'N' || jobu == 'O'
230231
CU_NULL
231232
else
232233
throw(ArgumentError("jobu is incorrect. The values accepted are 'A', 'S', 'O' and 'N'."))
233234
end
234-
Σ = CuVector{R}(undef, min(m, n))
235+
Σ = CuVector{R}(undef, k)
235236
Vt = if jobvt == 'A'
236237
CuMatrix{T}(undef, n, n)
237-
elseif jobvt == 'S' || jobvt == 'O'
238-
CuMatrix{T}(undef, min(m, n), n)
239-
elseif jobvt == 'N'
238+
elseif jobvt == 'S'
239+
CuMatrix{T}(undef, k, n)
240+
elseif jobvt == 'N' || jobvt == 'O'
240241
CU_NULL
241242
else
242243
throw(ArgumentError("jobvt is incorrect. The values accepted are 'A', 'S', 'O' and 'N'."))

test/libraries/cusolver/dense.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,23 @@ k = 1
194194
@test p h_TAUP
195195
end
196196

197+
@testset "gesvd!" begin
198+
A = rand(elty,m,n)
199+
d_A = CuMatrix(A)
200+
U, Σ, Vt = CUSOLVER.gesvd!('A', 'A', d_A)
201+
@test A collect(U[:,1:n] * Diagonal(Σ) * Vt)
202+
203+
for jobu in ('A', 'S', 'N', 'O')
204+
for jobvt in ('A', 'S', 'N', 'O')
205+
(jobu == 'A') && (jobvt == 'A') && continue
206+
(jobu == 'O') && (jobvt == 'O') && continue
207+
d_A = CuMatrix(A)
208+
U2, Σ2, Vt2 = CUSOLVER.gesvd!(jobu, jobvt, d_A)
209+
@test Σ Σ2
210+
end
211+
end
212+
end
213+
197214
@testset "syevd!" begin
198215
A = rand(elty,m,m)
199216
A += A'

test/libraries/cusolver/dense_generic.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,16 @@ p = 5
149149
d_A = CuMatrix(A)
150150
U, Σ, Vt = CUSOLVER.Xgesvd!('A', 'A', d_A)
151151
@test A collect(U[:,1:n] * Diagonal(Σ) * Vt)
152+
153+
for jobu in ('A', 'S', 'N', 'O')
154+
for jobvt in ('A', 'S', 'N', 'O')
155+
(jobu == 'A') && (jobvt == 'A') && continue
156+
(jobu == 'O') && (jobvt == 'O') && continue
157+
d_A = CuMatrix(A)
158+
U2, Σ2, Vt2 = CUSOLVER.Xgesvd!(jobu, jobvt, d_A)
159+
@test Σ Σ2
160+
end
161+
end
152162
end
153163

154164
@testset "gesvdp!" begin

0 commit comments

Comments
 (0)