Skip to content

Commit bd4f6aa

Browse files
committed
Fix sparse COO to CSR conversion.
1 parent dcbab7d commit bd4f6aa

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

lib/cusparse/conversions.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -349,23 +349,23 @@ end
349349

350350
## CSR to COO and vice-versa
351351

352+
# TODO: we can do similar for CSC conversions, but that requires the columns to be sorted
353+
352354
function CuSparseMatrixCSR(coo::CuSparseMatrixCOO{Tv}, ind::SparseChar='O') where {Tv}
353355
m,n = size(coo)
354-
nnz = coo.nnz
355-
csrRowPtr = CUDA.zeros(Cint, nnz)
356-
cusparseXcoo2csr(handle(), coo.rowInd, nnz, m, csrRowPtr, ind)
357-
CuSparseMatrixCSR{Tv}(csrRowPtr, coo.colInd, coo.nzVal, size(coo))
356+
csrRowPtr = CuVector{Cint}(undef, m+1)
357+
cusparseXcoo2csr(handle(), coo.rowInd, nnz(coo), m, csrRowPtr, ind)
358+
CuSparseMatrixCSR{Tv}(csrRowPtr, coo.colInd, nonzeros(coo), size(coo))
358359
end
359360

360361
function CuSparseMatrixCOO(csr::CuSparseMatrixCSR{Tv}, ind::SparseChar='O') where {Tv}
361362
m,n = size(csr)
362-
nnz = csr.nnz
363-
cooRowInd = CUDA.zeros(Cint, nnz)
364-
cusparseXcsr2coo(handle(), csr.rowPtr, nnz, m, cooRowInd, ind)
365-
CuSparseMatrixCOO{Tv}(cooRowInd, csr.colVal, csr.nzVal, size(csr), nnz)
363+
cooRowInd = CuVector{Cint}(undef, nnz(csr))
364+
cusparseXcsr2coo(handle(), csr.rowPtr, nnz(csr), m, cooRowInd, ind)
365+
CuSparseMatrixCOO{Tv}(cooRowInd, csr.colVal, nonzeros(csr), size(csr), nnz(csr))
366366
end
367367

368-
### CSC/BST to COO and viceversa
368+
### CSC/BSR to COO and viceversa
369369

370370
CuSparseMatrixCSC(coo::CuSparseMatrixCOO) = CuSparseMatrixCSC(CuSparseMatrixCSR(coo)) # no direct conversion
371371
CuSparseMatrixCOO(csc::CuSparseMatrixCSC) = CuSparseMatrixCOO(CuSparseMatrixCSR(csc)) # no direct conversion

0 commit comments

Comments
 (0)