@@ -169,7 +169,7 @@ for (bname, fname,elty) in ((:cusolverDnSgeqrf_bufferSize, :cusolverDnSgeqrf, :F
169
169
170
170
function geqrf! (A:: StridedCuMatrix{$elty} )
171
171
m, n = size (A)
172
- tau = CuArray { $elty} (undef , min (m, n))
172
+ tau = similar (A, $ elty, min (m,n))
173
173
geqrf! (A, tau)
174
174
end
175
175
end
@@ -361,10 +361,10 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgebrd_bufferSize, :cusolverDnSg
361
361
end
362
362
363
363
k = min (m, n)
364
- D = CuArray { $relty} (undef , k)
365
- E = CUDA . zeros ( $ relty, k)
366
- TAUQ = CuArray { $elty} (undef , k)
367
- TAUP = CuArray { $elty} (undef , k)
364
+ D = similar (A, $ relty, k)
365
+ E = fill! ( similar (A, $ relty, k), zero ( $ relty) )
366
+ TAUQ = similar (A, $ elty, k)
367
+ TAUP = similar (A, $ elty, k)
368
368
369
369
with_workspace (dh. workspace_gpu, bufferSize) do buffer
370
370
$ fname (dh, m, n, A, lda, D, E, TAUQ, TAUP,
@@ -392,22 +392,22 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg
392
392
lda = max (1 , stride (A, 2 ))
393
393
394
394
U = if jobu === ' A'
395
- CuArray { $elty} (undef, m, m)
395
+ similar (A, $ elty, ( m, m) )
396
396
elseif jobu == ' S' || jobu === ' O'
397
- CuArray { $elty} (undef, m, min (m, n))
397
+ similar (A, $ elty, ( m, min (m, n) ))
398
398
elseif jobu === ' N'
399
399
CU_NULL
400
400
else
401
401
error (" jobu must be one of 'A', 'S', 'O', or 'N'" )
402
402
end
403
403
ldu = U == CU_NULL ? 1 : max (1 , stride (U, 2 ))
404
404
405
- S = CuArray { $relty} (undef , min (m, n))
405
+ S = similar (A, $ relty, min (m, n))
406
406
407
407
Vt = if jobvt === ' A'
408
- CuArray { $elty} (undef, n, n)
408
+ similar (A, $ elty, ( n, n) )
409
409
elseif jobvt === ' S' || jobvt === ' O'
410
- CuArray { $elty} (undef, min (m, n), n)
410
+ similar (A, $ elty, ( min (m, n), n) )
411
411
elseif jobvt === ' N'
412
412
CU_NULL
413
413
else
@@ -449,22 +449,21 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdj_bufferSize, :cusolverDnS
449
449
max_sweeps:: Int = 100 )
450
450
m,n = size (A)
451
451
lda = max (1 , stride (A, 2 ))
452
-
453
452
# Warning! For some reason, the solver needs to access U and V even
454
453
# when only the values are requested
455
454
U = if jobz === ' V' && econ == 1 && m > n
456
- CuArray { $elty} (undef, m, n )
455
+ similar (A, $ elty, (m ,n) )
457
456
else
458
- CuArray { $elty} (undef, m, m)
457
+ similar (A, $ elty, ( m, m) )
459
458
end
460
459
ldu = max (1 , stride (U, 2 ))
461
460
462
- S = CuArray { $relty} (undef , min (m, n))
461
+ S = similar (A, $ relty, min (m,n))
463
462
464
463
V = if jobz === ' V' && econ == 1 && m < n
465
- CuArray { $elty} (undef, n, m)
464
+ similar (A, $ elty, ( n, m) )
466
465
else
467
- CuArray { $elty} (undef, n, n)
466
+ similar (A, $ elty, ( n, n) )
468
467
end
469
468
ldv = max (1 , stride (V, 2 ))
470
469
@@ -510,13 +509,13 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdjBatched_bufferSize, :cuso
510
509
throw (ArgumentError (" CUSOLVER's gesvdjBatched currently requires m <=32 and n <= 32" ))
511
510
end
512
511
lda = max (1 , stride (A, 2 ))
513
-
514
- U = CuArray { $elty} (undef, m, m, batchSize)
512
+
513
+ U = similar (A, $ elty, ( m, m, batchSize) )
515
514
ldu = max (1 , stride (U, 2 ))
516
515
517
- S = CuArray { $relty} (undef, min (m, n), batchSize)
516
+ S = similar (A, $ relty, ( min (m, n), batchSize) )
518
517
519
- V = CuArray { $elty} (undef , n, n, batchSize)
518
+ V = similar (A, $ elty, n, n, batchSize)
520
519
ldv = max (1 , stride (V, 2 ))
521
520
522
521
params = Ref {gesvdjInfo_t} (C_NULL )
@@ -569,14 +568,14 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdaStridedBatched_bufferSize
569
568
lda = max (1 , stride (A, 2 ))
570
569
strideA = stride (A, 3 )
571
570
572
- U = CuArray { $elty} (undef, m, rank, batchSize)
571
+ U = similar (A, $ elty, ( m, rank, batchSize) )
573
572
ldu = max (1 , stride (U, 2 ))
574
573
strideU = stride (U, 3 )
575
574
576
- S = CuArray { $relty} (undef, rank, batchSize)
575
+ S = similar (A, $ relty, ( rank, batchSize) )
577
576
strideS = stride (S, 2 )
578
577
579
- V = CuArray { $elty} (undef, n, rank, batchSize)
578
+ V = similar (A, $ elty, ( n, rank, batchSize) )
580
579
ldv = max (1 , stride (V, 2 ))
581
580
strideV = stride (V, 3 )
582
581
0 commit comments