Skip to content

Commit dcbab7d

Browse files
authored
Merge pull request #1410 from JuliaGPU/tb/sparse_int
Support limited sparse integer arrays by bitcasting to floating point.
2 parents 285c7a5 + 50e92f1 commit dcbab7d

File tree

3 files changed

+99
-13
lines changed

3 files changed

+99
-13
lines changed

lib/cusparse/conversions.jl

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ function CuSparseMatrixCSR{T}(S::Adjoint{T, <:CuSparseMatrixCSC{T}}) where {T <:
7474
return CuSparseMatrixCSR{T}(csc.colPtr, csc.rowVal, conj.(csc.nzVal), size(csc))
7575
end
7676

77-
7877
# by flipping rows and columns, we can use that to get CSC to CSR too
7978
for (fname,elty) in ((:cusparseScsr2csc, :Float32),
8079
(:cusparseDcsr2csc, :Float64),
@@ -141,6 +140,8 @@ for (fname,elty) in ((:cusparseScsr2csc, :Float32),
141140
end
142141
end
143142

143+
# implement Float16 conversions using wider types
144+
# TODO: Float16 is sometimes natively supported
144145
for (elty, welty) in ((:Float16, :Float32),
145146
(:ComplexF16, :ComplexF32))
146147
@eval begin
@@ -204,6 +205,46 @@ for (elty, welty) in ((:Float16, :Float32),
204205
end
205206
end
206207

208+
# implement Int conversions using reinterpreted Float
209+
for (elty, felty) in ((:Int16, :Float16),
210+
(:Int32, :Float32),
211+
(:Int64, :Float64),
212+
(:Int128, :ComplexF64))
213+
@eval begin
214+
function CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty})
215+
csc_compat = CuSparseMatrixCSC(
216+
csc.colPtr,
217+
csc.rowVal,
218+
reinterpret($felty, csc.nzVal),
219+
size(csc)
220+
)
221+
csr_compat = CuSparseMatrixCSR(csc_compat)
222+
CuSparseMatrixCSR(
223+
csr_compat.rowPtr,
224+
csr_compat.colVal,
225+
reinterpret($elty, csr_compat.nzVal),
226+
size(csr_compat)
227+
)
228+
end
229+
230+
function CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty})
231+
csr_compat = CuSparseMatrixCSR(
232+
csr.rowPtr,
233+
csr.colVal,
234+
reinterpret($felty, csr.nzVal),
235+
size(csr)
236+
)
237+
csc_compat = CuSparseMatrixCSC(csr_compat)
238+
CuSparseMatrixCSC(
239+
csc_compat.colPtr,
240+
csc_compat.rowVal,
241+
reinterpret($elty, csc_compat.nzVal),
242+
size(csc_compat)
243+
)
244+
end
245+
end
246+
end
247+
207248
## CSR to BSR and vice-versa
208249

209250
for (fname,elty) in ((:cusparseScsr2bsr, :Float32),
@@ -260,6 +301,52 @@ for (fname,elty) in ((:cusparseSbsr2csr, :Float32),
260301
end
261302
end
262303

304+
# implement Int conversions using reinterpreted Float
305+
for (elty, felty) in ((:Int16, :Float16),
306+
(:Int32, :Float32),
307+
(:Int64, :Float64),
308+
(:Int128, :ComplexF64))
309+
@eval begin
310+
function CuSparseMatrixCSR{$elty}(bsr::CuSparseMatrixBSR{$elty})
311+
bsr_compat = CuSparseMatrixBSR(
312+
bsr.rowPtr,
313+
bsr.colVal,
314+
reinterpret($felty, bsr.nzVal),
315+
bsr.blockDim,
316+
bsr.dir,
317+
bsr.nnzb,
318+
size(bsr)
319+
)
320+
csr_compat = CuSparseMatrixCSR(bsr_compat)
321+
CuSparseMatrixCSR(
322+
csr_compat.rowPtr,
323+
csr_compat.colVal,
324+
reinterpret($elty, csr_compat.nzVal),
325+
size(csr_compat)
326+
)
327+
end
328+
329+
function CuSparseMatrixBSR{$elty}(csr::CuSparseMatrixCSR{$elty}, blockDim)
330+
csr_compat = CuSparseMatrixCSR(
331+
csr.rowPtr,
332+
csr.colVal,
333+
reinterpret($felty, csr.nzVal),
334+
size(csr)
335+
)
336+
bsr_compat = CuSparseMatrixBSR(csr_compat, blockDim)
337+
CuSparseMatrixBSR(
338+
bsr_compat.rowPtr,
339+
bsr_compat.colVal,
340+
reinterpret($elty, bsr_compat.nzVal),
341+
bsr_compat.blockDim,
342+
bsr_compat.dir,
343+
bsr_compat.nnzb,
344+
size(bsr_compat)
345+
)
346+
end
347+
end
348+
end
349+
263350
## CSR to COO and vice-versa
264351

265352
function CuSparseMatrixCSR(coo::CuSparseMatrixCOO{Tv}, ind::SparseChar='O') where {Tv}

test/cusparse.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ blockdim = 5
108108
end
109109

110110
@testset "construction" begin
111-
@testset for elty in [Float32, Float64, ComplexF32, ComplexF64]
111+
@testset for elty in [Int32, Int64, Float32, Float64, ComplexF32, ComplexF64]
112112
@testset "vector" begin
113113
x = sprand(elty,m, 0.2)
114114
d_x = CuSparseVector(x)

test/cusparse/broadcast.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,35 +3,34 @@ using CUDA.CUSPARSE, SparseArrays
33
m,n = 2,3
44
p = 0.5
55

6-
# we rely on CUSPARSE conversions, so only test supported types
7-
for elty in [Float32, Float64]
8-
@testset "$typ" for typ in [CuSparseMatrixCSR, CuSparseMatrixCSC]
6+
for elty in [Int32, Int64, Float32, Float64]
7+
@testset "$typ($elty)" for typ in [CuSparseMatrixCSR, CuSparseMatrixCSC]
98
x = sprand(elty, m, n, p)
109
dx = typ(x)
1110

1211
# zero-preserving
13-
y = x .* 1
14-
dy = dx .* 1
12+
y = x .* elty(1)
13+
dy = dx .* elty(1)
1514
@test dy isa typ{elty}
1615
@test y == SparseMatrixCSC(dy)
1716

1817
# not zero-preserving
19-
y = x .+ 1
20-
dy = dx .+ 1
18+
y = x .+ elty(1)
19+
dy = dx .+ elty(1)
2120
@test dy isa CuArray{elty}
2221
@test y == Array(dy)
2322

2423
# involving something dense
25-
y = x .* ones(m, n)
26-
dy = dx .* CUDA.ones(m, n)
24+
y = x .* ones(elty, m, n)
25+
dy = dx .* CUDA.ones(elty, m, n)
2726
@test dy isa CuArray{elty}
2827
@test y == Array(dy)
2928

3029
# multiple inputs
3130
y = sprand(elty, m, n, p)
3231
dy = typ(y)
33-
z = x .* y .* 2
34-
dz = dx .* dy .* 2
32+
z = x .* y .* elty(2)
33+
dz = dx .* dy .* elty(2)
3534
@test dz isa typ{elty}
3635
@test z == SparseMatrixCSC(dz)
3736
end

0 commit comments

Comments
 (0)