@@ -217,12 +217,12 @@ for (fname,elty) in ((:cusolverDnSgetrs, :Float32),
217
217
trans = $ elty <: Real && trans == ' C' ? ' T' : trans
218
218
219
219
chktrans (trans)
220
- n = size (A, 1 )
221
- if size (A, 2 ) != n
222
- throw (DimensionMismatch (" LU factored matrix A must be square!" ))
223
- end
220
+ n = checksquare (A)
224
221
if size (B, 1 ) != n
225
- throw (DimensionMismatch (" first dimension of B, $(size (B,1 )) , must match second dimension of A, $n " ))
222
+ throw (DimensionMismatch (" first dimension of B, $(size (B,1 )) , must match dimension of A, $n " ))
223
+ end
224
+ if length (ipiv) != n
225
+ throw (DimensionMismatch (" length of ipiv, $(length (ipiv)) , must match dimension of A, $n " ))
226
226
end
227
227
nrhs = size (B, 2 )
228
228
lda = max (1 , stride (A, 2 ))
@@ -379,28 +379,37 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg
379
379
m, n = size (A)
380
380
if m < n
381
381
throw (ArgumentError (" CUSOLVER's gesvd currently requires m >= n" ))
382
+ # XXX : is this documented?
382
383
end
383
384
lda = max (1 , stride (A, 2 ))
384
385
385
- U = if jobu === ' S' && m > n
386
- CuArray {$elty} (undef, m, n)
386
+ U = if jobu === ' A'
387
+ CuArray {$elty} (undef, m, m)
388
+ elseif jobu == ' S'
389
+ CuArray {$elty} (undef, m, min (m, n))
390
+ elseif jobu === ' O'
391
+ CuArray {$elty} (undef, m, min (m, n))
387
392
elseif jobu === ' N'
388
- CuArray {$elty} (undef, m, 0 )
393
+ CU_NULL
389
394
else
390
- CuArray {$elty} (undef, m, m )
395
+ error ( " jobu must be one of 'A', 'S', 'O', or 'N' " )
391
396
end
392
- ldu = max (1 , stride (U, 2 ))
397
+ ldu = U == CU_NULL ? 1 : max (1 , stride (U, 2 ))
393
398
394
399
S = CuArray {$relty} (undef, min (m, n))
395
400
396
- Vt = if jobvt === ' S' && m < n
397
- CuArray {$elty} (undef, m, n)
401
+ Vt = if jobvt === ' A'
402
+ CuArray {$elty} (undef, n, n)
403
+ elseif jobu === ' S'
404
+ CuArray {$elty} (undef, min (m, n), n)
405
+ elseif jobu === ' O'
406
+ CuArray {$elty} (undef, min (m, n), n)
398
407
elseif jobvt === ' N'
399
- CuArray {$elty} (undef, 0 , n)
408
+ CU_NULL
400
409
else
401
- CuArray {$elty} (undef, n, n )
410
+ error ( " jobvt must be one of 'A', 'S', 'O', or 'N' " )
402
411
end
403
- ldvt = max (1 , stride (Vt, 2 ))
412
+ ldvt = Vt == CU_NULL ? 1 : max (1 , stride (Vt, 2 ))
404
413
405
414
function bufferSize ()
406
415
out = Ref {Cint} (0 )
@@ -669,6 +678,8 @@ for (jname, bname, fname, elty, relty) in ((:syevjBatched!, :cusolverDnSsyevjBat
669
678
chkargsok (BlasInt (info[i]))
670
679
end
671
680
681
+ cusolverDnDestroySyevjInfo (params[])
682
+
672
683
# Return eigenvalues (in W) and possibly eigenvectors (in A)
673
684
if jobz == ' N'
674
685
return W
0 commit comments