@@ -324,6 +324,7 @@ function LinearAlgebra.generic_matmatmul!(C::CuVecOrMat, tA, tB, A::StridedCuVec
324
324
return hemm! (' R' , tB == ' H' ? ' U' : ' L' , alpha, B, A, beta, C)
325
325
end
326
326
end
327
+
327
328
GPUArrays. generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), alpha, beta)
328
329
end
329
330
@@ -691,13 +692,13 @@ function LinearAlgebra.kron!(C::CuMatrix{TC}, A::CuMatrix{TA}, B::CuMatrix{TB})
691
692
function _kron_mat_kernelA! (C, A, B, m, n, p, q)
692
693
index_i = (blockIdx (). x - 1 ) * blockDim (). x + threadIdx (). x
693
694
index_j = (blockIdx (). y - 1 ) * blockDim (). y + threadIdx (). y
694
-
695
+
695
696
stride_i = blockDim (). x * gridDim (). x
696
697
stride_j = blockDim (). y * gridDim (). y
697
-
698
+
698
699
index_i > m && return
699
700
index_j > n && return
700
-
701
+
701
702
for i in index_i: stride_i: m
702
703
for j in index_j: stride_j: n
703
704
for k in 1 : p
@@ -713,13 +714,13 @@ function LinearAlgebra.kron!(C::CuMatrix{TC}, A::CuMatrix{TA}, B::CuMatrix{TB})
713
714
function _kron_mat_kernelB! (C, A, B, m, n, p, q)
714
715
index_p = (blockIdx (). x - 1 ) * blockDim (). x + threadIdx (). x
715
716
index_q = (blockIdx (). y - 1 ) * blockDim (). y + threadIdx (). y
716
-
717
+
717
718
stride_p = blockDim (). x * gridDim (). x
718
719
stride_q = blockDim (). y * gridDim (). y
719
-
720
+
720
721
index_p > p && return
721
722
index_q > q && return
722
-
723
+
723
724
for i in 1 : m
724
725
for j in 1 : n
725
726
for k in index_p: stride_p: p
@@ -737,7 +738,7 @@ function LinearAlgebra.kron!(C::CuMatrix{TC}, A::CuMatrix{TA}, B::CuMatrix{TB})
737
738
738
739
# Use different kernels depending on the size of the matrices
739
740
# choosing to parallelize the matrix with the largest number of elements
740
- m* n >= p* q ? (kernel = @cuda launch= false _kron_mat_kernelA! (C, A, B, m, n, p, q)) :
741
+ m* n >= p* q ? (kernel = @cuda launch= false _kron_mat_kernelA! (C, A, B, m, n, p, q)) :
741
742
(kernel = @cuda launch= false _kron_mat_kernelB! (C, A, B, m, n, p, q))
742
743
743
744
m* n >= p* q ? (sizes = (m, n)) : (sizes = (p, q))
764
765
# Allocating counterpart of `kron!` for two `CuMatrix` arguments.
#
# The output is created with `similar(A, ...)`, so it takes its array type from
# `A`; its element type is `promote_type(TA, TB)` and its size is
# `(size(A, 1) * size(B, 1), size(A, 2) * size(B, 2))`. The freshly allocated
# matrix is handed to `kron!(out, A, B)` and the result of that call is returned.
function LinearAlgebra.kron(A::CuMatrix{TA}, B::CuMatrix{TB}) where {TA,TB}
    rows_a, cols_a = size(A)
    rows_b, cols_b = size(B)

    # Element type wide enough to hold entries coming from either input.
    out = similar(A, promote_type(TA, TB), rows_a * rows_b, cols_a * cols_b)

    return kron!(out, A, B)
end
0 commit comments