Skip to content

Commit c707e2e

Browse files
Provide CUDA support for the binarize() operation on sparse matrices (#601)
* Add binarize function for CuSparseMatrixCSC and update its usage * Refactor CUDA binarize function to use Bool as the CPU version * Revert unrelated changes, make coherent PR on binarize() and nothing else * Removed debug prints Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com> * Unify binarize function signature * Remove debug print statement from binarize function --------- Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
1 parent 4a426fe commit c707e2e

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

GNNGraphs/ext/GNNGraphsCUDAExt.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using CUDA
44
using Random, Statistics, LinearAlgebra
55
using GNNGraphs
66
using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T
7+
using SparseArrays
78

89
const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}
910

@@ -20,6 +21,11 @@ GNNGraphs.dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a)) = CUDA.zeros(T, sz
2021

2122
GNNGraphs.iscuarray(x::AnyCuArray) = true
2223

24+
function GNNGraphs.binarize(Mat::CUSPARSE.CuSparseMatrixCSC, T::DataType = Bool)
25+
bin_vals = fill!(similar(nonzeros(Mat)), one(T))
26+
return CUSPARSE.CuSparseMatrixCSC(Mat.colPtr, rowvals(Mat), bin_vals, size(Mat))
27+
end
28+
2329

2430
function sort_edge_index(u::AnyCuArray, v::AnyCuArray)
2531
dev = get_device(u)

GNNGraphs/src/query.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType = eltype(g
235235
@assert dir [:in, :out]
236236
A = g.graph
237237
if !weighted
238-
A = binarize(A)
238+
A = binarize(A, T)
239239
end
240240
A = T != eltype(A) ? T.(A) : A
241241
return dir == :out ? A : A'
@@ -377,7 +377,7 @@ end
377377

378378
function _degree(A::AbstractMatrix, T::Type, dir::Symbol, edge_weight::Bool, num_nodes::Int)
379379
if edge_weight === false
380-
A = binarize(A)
380+
A = binarize(A, T)
381381
end
382382
A = eltype(A) != T ? T.(A) : A
383383
return dir == :out ? vec(sum(A, dims = 2)) :

GNNGraphs/src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ function _rand_edges(rng, (n1, n2), m)
295295
return s, t, val
296296
end
297297

298-
binarize(x) = map(>(0), x)
298+
binarize(x, T::DataType = Bool) = ifelse.(x .> 0, one(T), zero(T))
299299

300300
CRC.@non_differentiable binarize(x...)
301301
CRC.@non_differentiable edge_encoding(x...)

0 commit comments

Comments
 (0)