Skip to content

Commit 9b327d6

Browse files
authored
Add propagate w_mul_xj CUDA sparse support using matrix mul (#610)
1 parent 0221593 commit 9b327d6

File tree

4 files changed

+19
-3
lines changed

4 files changed

+19
-3
lines changed

GNNlib/ext/GNNlibCUDAExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ end
 ## W_MUL_XJ
 
 ## avoid the fast path on gpu until we have better cuda support
-function GNNlib.propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
+function GNNlib.propagate(::typeof(w_mul_xj), g::GNNGraph{COO_T}, ::typeof(+),
                           xi, xj::AnyCuMatrix, e::Nothing)
     propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e)
 end

GNNlib/src/msgpass.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ end
 # for weighted convolution
 function propagate(::typeof(w_mul_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix,
                    e::Nothing)
-    A = adjacency_matrix(g, weighted = true)
+    A = adjacency_matrix(g, eltype(xj); weighted = true)
     return xj * A
 end

GraphNeuralNetworks/perf/Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,3 @@ GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
 GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
-Graphs = "093fc24a-ae57-5d10-9952-331d41423f4d"

GraphNeuralNetworks/perf/sparse_propagate_cuda.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,23 @@ function prop_copy_xj(graph_type, sp_p, n, feat_size)
     return nothing
 end
 
+function prop_w_mul_xj(graph_type, sp_p, n, feat_size)
+    A = sprand(n, n, sp_p)
+    b = rand(1, n)
+    B = rand(feat_size, n)
+    g = GNNGraph(A,
+                 ndata = (; b = b, B = B),
+                 edata = (; A = reshape(A.nzval, 1, :)),
+                 graph_type = graph_type) |> dev
+    printstyled("propagate w_mul_xj for graph type: $graph_type", "\n", color=:yellow)
+    CUDA.@sync propagate(w_mul_xj, g, +; xj = g.ndata.B) # run once to compile before benchmarking
+    @btime CUDA.@sync propagate($w_mul_xj, $g, +; xj = $g.ndata.B) # using spmm for :sparse
+    printstyled("gather/scatter propagate w_mul_xj for graph type: $graph_type", "\n", color=:yellow)
+    CUDA.@sync propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +; xj = g.ndata.B) # run once to compile before benchmarking
+    @btime CUDA.@sync propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), $g, +; xj = $g.ndata.B) # using gather/scatter
+    return nothing
+end
+
 seed!(0)
 dev = gpu_device()
 println("Device: ", dev)

0 commit comments

Comments
 (0)