
Commit 66ab572

Remove shape-preserving Diagonal conversion constructors. (#2805)
Conversion constructors should always collect the diagonal and return a dense matrix; for shape-preserving conversions, use Adapt instead.
1 parent 740e888 commit 66ab572
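
To make the new semantics concrete, here is a short usage sketch. It is not part of the commit and the names are illustrative; it assumes CUDA.jl with this change, plus Adapt.jl:

using CUDA, Adapt
using LinearAlgebra

D = Diagonal(rand(Float32, 4))

# The conversion constructors now collect: the result is a dense matrix
# with the diagonal filled in and zeros elsewhere.
dD = CuArray(D)          # 4×4 dense CuMatrix{Float32}

# Shape-preserving conversion is Adapt's job:
dD2 = adapt(CuArray, D)  # Diagonal wrapper kept, diagonal stored in a CuVector

# The same rule applies on the way back to the CPU:
Array(dD2)               # dense Matrix{Float32}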

File tree: 10 files changed (+41, -20 lines)

lib/cublas/CUBLAS.jl

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,8 @@ using LLVM.Interop: assume
 
 using CEnum: @cenum
 
+using Adapt: adapt
+
 
 const cudaDataType_t = cudaDataType
 

lib/cublas/linalg.jl

Lines changed: 29 additions & 2 deletions
@@ -328,8 +328,35 @@ for (t, uploc, isunitc) in ((:LowerTriangular, 'L', 'N'),
 end
 
 # Diagonal
-Base.Array(D::Diagonal{T, <:CuArray{T}}) where {T} = Diagonal(Array(D.diag))
-CuArray(D::Diagonal{T, <:Vector{T}}) where {T} = Diagonal(CuArray(D.diag))
+
+# conversions to dense matrices
+Base.Array(D::Diagonal{T, <:CuArray{T}}) where {T} = Array(Diagonal(Array(D.diag)))
+CUDA.CuArray(D::Diagonal{T}) where {T} = CuMatrix(D)
+function CUDA.CuMatrix{T}(D::Diagonal) where {T}
+    n = size(D, 1)
+    B = CUDA.zeros(T, n, n)
+    n == 0 && return B
+
+    gpu_diag = adapt(CuArray, D.diag)
+    ## COV_EXCL_START
+    function fill_diagonal()
+        i = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
+        grid_stride = gridDim().x * blockDim().x
+        while i <= n
+            @inbounds B[i,i] = gpu_diag[i]
+            i += grid_stride
+        end
+        return
+    end
+    ## COV_EXCL_STOP
+    kernel = @cuda launch=false fill_diagonal()
+    config = launch_configuration(kernel.fun)
+    threads = min(config.threads, n)
+    blocks = min(config.blocks, cld(n, threads))
+    kernel(; threads, blocks)
+
+    return B
+end
 
 function LinearAlgebra.inv(D::Diagonal{T, <:CuArray{T}}) where {T}
     Di = map(inv, D.diag)

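The new CuMatrix{T} constructor above uses CUDA.jl's usual occupancy-driven launch idiom: compile the kernel with @cuda launch=false, ask launch_configuration for a suggested thread/block count, clamp it to the problem size, then launch the compiled kernel. Below is a minimal standalone sketch of the same grid-stride pattern; fill_value! is an illustrative name, not part of the commit:

using CUDA

function fill_value!(xs::CuVector{T}, val::T) where {T}
    n = length(xs)
    n == 0 && return xs

    # grid-stride loop: each thread handles indices i, i+stride, i+2*stride, ...
    function kernel_fn()
        i = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
        stride = gridDim().x * blockDim().x
        while i <= n
            @inbounds xs[i] = val
            i += stride
        end
        return
    end

    kernel = @cuda launch=false kernel_fn()     # compile without launching
    config = launch_configuration(kernel.fun)   # occupancy-based suggestion
    threads = min(config.threads, n)
    blocks = min(config.blocks, cld(n, threads))
    kernel(; threads, blocks)                   # launch the compiled kernel
    return xs
end

For example, fill_value!(CUDA.zeros(Float32, 1000), 1f0) fills the vector with ones while letting the occupancy API pick the launch dimensions.
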
test/base/array.jl

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 using LinearAlgebra
-import Adapt
 using ChainRulesCore: add!!, is_inplaceable_destination
 
 @testset "constructors" begin

test/base/kernelabstractions.jl

Lines changed: 0 additions & 1 deletion
@@ -1,7 +1,6 @@
 import KernelAbstractions
 import KernelAbstractions as KA
 using SparseArrays
-using Adapt
 
 include(joinpath(dirname(pathof(KernelAbstractions)), "..", "test", "testsuite.jl"))
 

test/core/execution.jl

Lines changed: 0 additions & 2 deletions
@@ -1,5 +1,3 @@
-import Adapt
-
 dummy() = return
 
 @testset "@cuda" begin

test/libraries/cublas/level3.jl

Lines changed: 8 additions & 8 deletions
@@ -127,7 +127,7 @@ k = 13
             uplotype in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriangular)
 
             A = triu(rand(elty, m, m))
-                dA = CuArray(A)
+            dA = CuArray(A)
             Br = rand(elty,m,n)
             Bl = rand(elty,n,m)
             d_Br = CuArray(Br)
@@ -136,16 +136,16 @@ k = 13
                 @test Bl/adjtype(uplotype(A)) ≈ Array(d_Bl/adjtype(uplotype(dA)))
             end
             # Check also that scaling parameter works
-                alpha = rand(elty)
+            alpha = rand(elty)
             A = triu(rand(elty, m, m))
-                dA = CuArray(A)
+            dA = CuArray(A)
             Br = rand(elty,m,n)
             d_Br = CuArray(Br)
             @test BLAS.trsm('L','U','N','N',alpha,A,Br) ≈ Array(CUBLAS.trsm('L','U','N','N',alpha,dA,d_Br))
         end
 
         @testset "trsm_batched!" begin
-                alpha = rand(elty)
+            alpha = rand(elty)
             bA = [rand(elty,m,m) for i in 1:10]
             map!((x) -> triu(x), bA, bA)
             bB = [rand(elty,m,n) for i in 1:10]
@@ -288,7 +288,7 @@ k = 13
         B = Diagonal(rand(elty, m))
 
         dA = CuArray(A)
-        dB = CuArray(B)
+        dB = adapt(CuArray, B)
 
         C = A / B
         d_C = dA / dB
@@ -297,7 +297,7 @@ k = 13
         @test C ≈ Array(dA)
         @test C ≈ h_C
 
-        B_bad = Diagonal(rand(elty, m+1))
+        B_bad = Diagonal(CuArray(rand(elty, m+1)))
         @test_throws DimensionMismatch("left hand side has $m columns but D is $(m+1) by $(m+1)") rdiv!(dA, B_bad)
     end
 
@@ -483,12 +483,12 @@ k = 13
 
         @test A ≈ h_A
         @test B ≈ h_B
-
+
         diagA = diagm(m, m, 0 => d_A)
         diagind_A = diagind(diagA, 0)
         h_A = Array(diagA[diagind_A])
         @test A ≈ h_A
-
+
         diagA = diagm(m, m, d_A)
         diagind_A = diagind(diagA, 0)
         h_A = Array(diagA[diagind_A])

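The adapt(CuArray, B) and Diagonal(CuArray(...)) changes above follow from the new conversion semantics: CuArray(B::Diagonal) now collects to a dense CuMatrix, so tests exercising the Diagonal-specialized rdiv! path must build a GPU-backed wrapper explicitly. A hedged sketch of the distinction:

using CUDA, Adapt, LinearAlgebra

B = Diagonal(rand(Float32, 4))

adapt(CuArray, B) isa Diagonal                     # true: wrapper preserved
Diagonal(CuArray(rand(Float32, 4))) isa Diagonal   # true: built around a CuVector
CuArray(B) isa Diagonal                            # false: now a dense CuMatrix
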
test/libraries/cusparse/conversions.jl

Lines changed: 0 additions & 2 deletions
@@ -1,8 +1,6 @@
 using LinearAlgebra
-using Adapt
 using CUDA.CUSPARSE
 using SparseArrays
-using CUDA
 
 @testset "sparse" begin
     n, m = 4, 4

test/libraries/cusparse/generic.jl

Lines changed: 0 additions & 2 deletions
@@ -1,5 +1,3 @@
-using CUDA
-using Adapt
 using CUDA.CUSPARSE
 using SparseArrays
 using LinearAlgebra

test/libraries/cusparse/interfaces.jl

Lines changed: 0 additions & 2 deletions
@@ -1,5 +1,3 @@
-using CUDA
-using Adapt
 using CUDA.CUSPARSE
 using LinearAlgebra, SparseArrays
 
test/setup.jl

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@ testf(f, xs...; kwargs...) = TestSuite.compare(f, CuArray, xs...; kwargs...)
 
 using Random
 
+using Adapt
+
 # detect compute-sanitizer, to disable incompatible tests (e.g. using CUPTI)
 const sanitize = any(contains("NV_SANITIZER"), keys(ENV))
 
