Skip to content

Commit ea6c42b

Browse files
authored
Refactor propagate function signatures to accept subtype of COO_T for copy_xj and w_mul_xj (#611)
1 parent 9b327d6 commit ea6c42b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

GNNlib/ext/GNNlibCUDAExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using GNNGraphs: GNNGraph, COO_T, SPARSE_T
1010
## COPY_XJ
1111

1212
## avoid the fast path on gpu until we have better cuda support
13-
function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{COO_T}, ::typeof(+),
13+
function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{<:COO_T}, ::typeof(+),
1414
xi, xj::AnyCuMatrix, e)
1515
propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e)
1616
end
@@ -26,7 +26,7 @@ end
2626
## W_MUL_XJ
2727

2828
## avoid the fast path on gpu until we have better cuda support
29-
function GNNlib.propagate(::typeof(w_mul_xj), g::GNNGraph{COO_T}, ::typeof(+),
29+
function GNNlib.propagate(::typeof(w_mul_xj), g::GNNGraph{<:COO_T}, ::typeof(+),
3030
xi, xj::AnyCuMatrix, e::Nothing)
3131
propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e)
3232
end

0 commit comments

Comments
 (0)