Skip to content

Commit 9204e34

Browse files
authored
Preserve array buffer type in several linalg routines (#2534)
1 parent: bd4effd · commit: 9204e34

File tree

1 file changed

+22
-23
lines changed

1 file changed

+22
-23
lines changed

lib/cusolver/dense.jl

Lines changed: 22 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -169,7 +169,7 @@ for (bname, fname,elty) in ((:cusolverDnSgeqrf_bufferSize, :cusolverDnSgeqrf, :F
169169

170170
function geqrf!(A::StridedCuMatrix{$elty})
171171
m, n = size(A)
172-
tau = CuArray{$elty}(undef, min(m, n))
172+
tau = similar(A, $elty, min(m,n))
173173
geqrf!(A, tau)
174174
end
175175
end
@@ -361,10 +361,10 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgebrd_bufferSize, :cusolverDnSg
361361
end
362362

363363
k = min(m, n)
364-
D = CuArray{$relty}(undef, k)
365-
E = CUDA.zeros($relty, k)
366-
TAUQ = CuArray{$elty}(undef, k)
367-
TAUP = CuArray{$elty}(undef, k)
364+
D = similar(A, $relty, k)
365+
E = fill!(similar(A, $relty, k), zero($relty))
366+
TAUQ = similar(A, $elty, k)
367+
TAUP = similar(A, $elty, k)
368368

369369
with_workspace(dh.workspace_gpu, bufferSize) do buffer
370370
$fname(dh, m, n, A, lda, D, E, TAUQ, TAUP,
@@ -392,22 +392,22 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg
392392
lda = max(1, stride(A, 2))
393393

394394
U = if jobu === 'A'
395-
CuArray{$elty}(undef, m, m)
395+
similar(A, $elty, (m, m))
396396
elseif jobu == 'S' || jobu === 'O'
397-
CuArray{$elty}(undef, m, min(m, n))
397+
similar(A, $elty, (m, min(m, n)))
398398
elseif jobu === 'N'
399399
CU_NULL
400400
else
401401
error("jobu must be one of 'A', 'S', 'O', or 'N'")
402402
end
403403
ldu = U == CU_NULL ? 1 : max(1, stride(U, 2))
404404

405-
S = CuArray{$relty}(undef, min(m, n))
405+
S = similar(A, $relty, min(m, n))
406406

407407
Vt = if jobvt === 'A'
408-
CuArray{$elty}(undef, n, n)
408+
similar(A, $elty, (n, n))
409409
elseif jobvt === 'S' || jobvt === 'O'
410-
CuArray{$elty}(undef, min(m, n), n)
410+
similar(A, $elty, (min(m, n), n))
411411
elseif jobvt === 'N'
412412
CU_NULL
413413
else
@@ -449,22 +449,21 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdj_bufferSize, :cusolverDnS
449449
max_sweeps::Int=100)
450450
m,n = size(A)
451451
lda = max(1, stride(A, 2))
452-
453452
# Warning! For some reason, the solver needs to access U and V even
454453
# when only the values are requested
455454
U = if jobz === 'V' && econ == 1 && m > n
456-
CuArray{$elty}(undef, m, n)
455+
similar(A, $elty, (m ,n))
457456
else
458-
CuArray{$elty}(undef, m, m)
457+
similar(A, $elty, (m, m))
459458
end
460459
ldu = max(1, stride(U, 2))
461460

462-
S = CuArray{$relty}(undef, min(m, n))
461+
S = similar(A, $relty, min(m,n))
463462

464463
V = if jobz === 'V' && econ == 1 && m < n
465-
CuArray{$elty}(undef, n, m)
464+
similar(A, $elty, (n, m))
466465
else
467-
CuArray{$elty}(undef, n, n)
466+
similar(A, $elty, (n, n))
468467
end
469468
ldv = max(1, stride(V, 2))
470469

@@ -510,13 +509,13 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdjBatched_bufferSize, :cuso
510509
throw(ArgumentError("CUSOLVER's gesvdjBatched currently requires m <=32 and n <= 32"))
511510
end
512511
lda = max(1, stride(A, 2))
513-
514-
U = CuArray{$elty}(undef, m, m, batchSize)
512+
513+
U = similar(A, $elty, (m, m, batchSize))
515514
ldu = max(1, stride(U, 2))
516515

517-
S = CuArray{$relty}(undef, min(m, n), batchSize)
516+
S = similar(A, $relty, (min(m, n), batchSize))
518517

519-
V = CuArray{$elty}(undef, n, n, batchSize)
518+
V = similar(A, $elty, n, n, batchSize)
520519
ldv = max(1, stride(V, 2))
521520

522521
params = Ref{gesvdjInfo_t}(C_NULL)
@@ -569,14 +568,14 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdaStridedBatched_bufferSize
569568
lda = max(1, stride(A, 2))
570569
strideA = stride(A, 3)
571570

572-
U = CuArray{$elty}(undef, m, rank, batchSize)
571+
U = similar(A, $elty, (m, rank, batchSize))
573572
ldu = max(1, stride(U, 2))
574573
strideU = stride(U, 3)
575574

576-
S = CuArray{$relty}(undef, rank, batchSize)
575+
S = similar(A, $relty, (rank, batchSize))
577576
strideS = stride(S, 2)
578577

579-
V = CuArray{$elty}(undef, n, rank, batchSize)
578+
V = similar(A, $elty, (n, rank, batchSize))
580579
ldv = max(1, stride(V, 2))
581580
strideV = stride(V, 3)
582581

0 commit comments

Comments
 (0)