diff --git a/GNNlib/ext/GNNlibCUDAExt.jl b/GNNlib/ext/GNNlibCUDAExt.jl index 78ab49262..56a6738e9 100644 --- a/GNNlib/ext/GNNlibCUDAExt.jl +++ b/GNNlib/ext/GNNlibCUDAExt.jl @@ -10,7 +10,7 @@ using GNNGraphs: GNNGraph, COO_T, SPARSE_T ## COPY_XJ ## avoid the fast path on gpu until we have better cuda support -function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{COO_T}, ::typeof(+), +function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{<:COO_T}, ::typeof(+), xi, xj::AnyCuMatrix, e) propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e) end @@ -26,7 +26,7 @@ end ## W_MUL_XJ ## avoid the fast path on gpu until we have better cuda support -function GNNlib.propagate(::typeof(w_mul_xj), g::GNNGraph{COO_T}, ::typeof(+), +function GNNlib.propagate(::typeof(w_mul_xj), g::GNNGraph{<:COO_T}, ::typeof(+), xi, xj::AnyCuMatrix, e::Nothing) propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e) end