Skip to content

Commit a50d4aa

Browse files
authored
CUSOLVER: Explicitly pass NULL when not requesting svd outputs. (#1934)
1 parent 75c0ba8 commit a50d4aa

File tree

1 file changed

+26
-15
lines changed

1 file changed

+26
-15
lines changed

lib/cusolver/dense.jl

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -217,12 +217,12 @@ for (fname,elty) in ((:cusolverDnSgetrs, :Float32),
217217
trans = $elty <: Real && trans == 'C' ? 'T' : trans
218218

219219
chktrans(trans)
220-
n = size(A, 1)
221-
if size(A, 2) != n
222-
throw(DimensionMismatch("LU factored matrix A must be square!"))
223-
end
220+
n = checksquare(A)
224221
if size(B, 1) != n
225-
throw(DimensionMismatch("first dimension of B, $(size(B,1)), must match second dimension of A, $n"))
222+
throw(DimensionMismatch("first dimension of B, $(size(B,1)), must match dimension of A, $n"))
223+
end
224+
if length(ipiv) != n
225+
throw(DimensionMismatch("length of ipiv, $(length(ipiv)), must match dimension of A, $n"))
226226
end
227227
nrhs = size(B, 2)
228228
lda = max(1, stride(A, 2))
@@ -379,28 +379,37 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg
379379
m, n = size(A)
380380
if m < n
381381
throw(ArgumentError("CUSOLVER's gesvd currently requires m >= n"))
382+
# XXX: is this documented?
382383
end
383384
lda = max(1, stride(A, 2))
384385

385-
U = if jobu === 'S' && m > n
386-
CuArray{$elty}(undef, m, n)
386+
U = if jobu === 'A'
387+
CuArray{$elty}(undef, m, m)
388+
elseif jobu == 'S'
389+
CuArray{$elty}(undef, m, min(m, n))
390+
elseif jobu === 'O'
391+
CuArray{$elty}(undef, m, min(m, n))
387392
elseif jobu === 'N'
388-
CuArray{$elty}(undef, m, 0)
393+
CU_NULL
389394
else
390-
CuArray{$elty}(undef, m, m)
395+
error("jobu must be one of 'A', 'S', 'O', or 'N'")
391396
end
392-
ldu = max(1, stride(U, 2))
397+
ldu = U == CU_NULL ? 1 : max(1, stride(U, 2))
393398

394399
S = CuArray{$relty}(undef, min(m, n))
395400

396-
Vt = if jobvt === 'S' && m < n
397-
CuArray{$elty}(undef, m, n)
401+
Vt = if jobvt === 'A'
402+
CuArray{$elty}(undef, n, n)
403+
elseif jobu === 'S'
404+
CuArray{$elty}(undef, min(m, n), n)
405+
elseif jobu === 'O'
406+
CuArray{$elty}(undef, min(m, n), n)
398407
elseif jobvt === 'N'
399-
CuArray{$elty}(undef, 0, n)
408+
CU_NULL
400409
else
401-
CuArray{$elty}(undef, n, n)
410+
error("jobvt must be one of 'A', 'S', 'O', or 'N'")
402411
end
403-
ldvt = max(1, stride(Vt, 2))
412+
ldvt = Vt == CU_NULL ? 1 : max(1, stride(Vt, 2))
404413

405414
function bufferSize()
406415
out = Ref{Cint}(0)
@@ -669,6 +678,8 @@ for (jname, bname, fname, elty, relty) in ((:syevjBatched!, :cusolverDnSsyevjBat
669678
chkargsok(BlasInt(info[i]))
670679
end
671680

681+
cusolverDnDestroySyevjInfo(params[])
682+
672683
# Return eigenvalues (in W) and possibly eigenvectors (in A)
673684
if jobz == 'N'
674685
return W

0 commit comments

Comments
 (0)