From d1f50b47c6674cbc187904ba8ee8e78e4fea03b6 Mon Sep 17 00:00:00 2001 From: dferre97 Date: Thu, 17 Jul 2025 10:11:53 +0200 Subject: [PATCH] Refactor propagate function signatures to accept subtype of COO_T for copy_xj and w_mul_xj --- GNNlib/ext/GNNlibCUDAExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/GNNlib/ext/GNNlibCUDAExt.jl b/GNNlib/ext/GNNlibCUDAExt.jl index 78ab49262..56a6738e9 100644 --- a/GNNlib/ext/GNNlibCUDAExt.jl +++ b/GNNlib/ext/GNNlibCUDAExt.jl @@ -10,7 +10,7 @@ using GNNGraphs: GNNGraph, COO_T, SPARSE_T ## COPY_XJ ## avoid the fast path on gpu until we have better cuda support -function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{COO_T}, ::typeof(+), +function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{<:COO_T}, ::typeof(+), xi, xj::AnyCuMatrix, e) propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e) end @@ -26,7 +26,7 @@ end ## W_MUL_XJ ## avoid the fast path on gpu until we have better cuda support -function GNNlib.propagate(::typeof(w_mul_xj), g::GNNGraph{COO_T}, ::typeof(+), +function GNNlib.propagate(::typeof(w_mul_xj), g::GNNGraph{<:COO_T}, ::typeof(+), xi, xj::AnyCuMatrix, e::Nothing) propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e) end