@@ -887,7 +887,7 @@ function gemmEx!(transA::Char, transB::Char,
887
887
k = size (A, transA == ' N' ? 2 : 1 )
888
888
n = size (B, transB == ' N' ? 2 : 1 )
889
889
if m != size (C,1 ) || n != size (C,2 ) || k != size (B, transB == ' N' ? 1 : 2 )
890
- throw (DimensionMismatch (" " ))
890
+ throw (DimensionMismatch (" A has dimension $( size (A)) , B has dimension $( size (B)) and C has dimension $( size (C)) " ))
891
891
end
892
892
lda = max (1 ,stride (A,2 ))
893
893
ldb = max (1 ,stride (B,2 ))
@@ -909,6 +909,91 @@ function gemmEx!(transA::Char, transB::Char,
909
909
C
910
910
end
911
911
912
"""
    gemmBatchedEx!(transA, transB, alpha, A, B, beta, C; algo=CUBLAS_GEMM_DEFAULT)

Batched mixed-precision matrix-matrix multiply: for every `i`,
`C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i]`, where `op` is determined by
`transA`/`transB` (`'N'`, `'T'`, or `'C'`). All batches must share the same
dimensions; `C` is mutated in place and returned. Requires CUDA 11 or higher
(`cublasGemmBatchedEx`).

# Throws
- `DimensionMismatch` if the batch lengths differ or any `A[i]`/`B[i]`/`C[i]`
  triple has incompatible shapes.
- `ArgumentError` if no compute type supports the given element-type combination.
"""
function gemmBatchedEx!(transA::Char, transB::Char,
                        @nospecialize(alpha::Number),
                        @nospecialize(A::Vector{<:StridedCuVecOrMat}),
                        @nospecialize(B::Vector{<:StridedCuVecOrMat}),
                        @nospecialize(beta::Number),
                        @nospecialize(C::Vector{<:StridedCuVecOrMat});
                        algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT)
    if length(A) != length(B) || length(A) != length(C)
        throw(DimensionMismatch("Lengths of inputs must be the same"))
    end
    # Validate every triple in the batch, not just the first one.
    for (i, (As, Bs, Cs)) in enumerate(zip(A, B, C))
        m = size(As, transA == 'N' ? 1 : 2)
        k = size(As, transA == 'N' ? 2 : 1)
        n = size(Bs, transB == 'N' ? 2 : 1)
        if m != size(Cs, 1) || n != size(Cs, 2) || k != size(Bs, transB == 'N' ? 1 : 2)
            throw(DimensionMismatch("Input $i: A has dimension $(size(As)), B has dimension $(size(Bs)), C has dimension $(size(Cs))"))
        end
    end
    # All batches share dimensions/strides, so derive them from the first entry.
    m = size(A[1], transA == 'N' ? 1 : 2)
    k = size(A[1], transA == 'N' ? 2 : 1)
    n = size(B[1], transB == 'N' ? 2 : 1)
    lda = max(1, stride(A[1], 2))
    ldb = max(1, stride(B[1], 2))
    ldc = max(1, stride(C[1], 2))
    computeType = gemmExComputeType(eltype(A[1]), eltype(B[1]), eltype(C[1]), m, k, n)
    # NOTE: report the scalar element types (eltype of the matrices), not the
    # container types, so the message matches what was actually checked.
    isnothing(computeType) &&
        throw(ArgumentError("gemmEx does not support $(eltype(C[1]))=$(eltype(A[1]))*$(eltype(B[1]))"))
    computeT = juliaStorageType(eltype(C[1]), computeType)
    # Device-side arrays of pointers to the individual batch entries.
    Aptrs = unsafe_batch(A)
    Bptrs = unsafe_batch(B)
    Cptrs = unsafe_batch(C)
    try
        if version() >= v"11.0"
            # with CUDA 11, the compute type encodes the math mode.
            cublasGemmBatchedEx(handle(), transA, transB, m, n, k,
                                Ref{computeT}(alpha), Aptrs, eltype(A[1]), lda,
                                Bptrs, eltype(B[1]), ldb,
                                Ref{computeT}(beta), Cptrs, eltype(C[1]), ldc,
                                length(A), computeType, algo)
        else
            # cublasGemmBatchedEx is only wrapped for CUDA 11+.
            error("gemmBatchedEx! requires CUDA 11 or higher")
        end
    finally
        # Free the pointer batches even if the cuBLAS call throws.
        unsafe_free!(Cptrs)
        unsafe_free!(Bptrs)
        unsafe_free!(Aptrs)
    end

    C
end
956
+
957
"""
    gemmStridedBatchedEx!(transA, transB, alpha, A, B, beta, C; algo=CUBLAS_GEMM_DEFAULT)

Strided-batched mixed-precision matrix-matrix multiply on 3-d arrays: for every
batch index `i` along dimension 3, `C[:,:,i] = alpha * op(A[:,:,i]) * op(B[:,:,i])
+ beta * C[:,:,i]`, where `op` is determined by `transA`/`transB`. A singleton
third dimension in `A` or `B` (stride 0) broadcasts that operand across the
batch. `C` is mutated in place and returned. Requires CUDA 11 or higher
(`cublasGemmStridedBatchedEx`).

# Throws
- `DimensionMismatch` if the batch counts differ or the matrix shapes are
  incompatible.
- `ArgumentError` if no compute type supports the given element-type combination.
"""
function gemmStridedBatchedEx!(transA::Char, transB::Char,
                               @nospecialize(alpha::Number),
                               @nospecialize(A::AbstractArray{Ta, 3}),
                               @nospecialize(B::AbstractArray{Tb, 3}),
                               @nospecialize(beta::Number),
                               @nospecialize(C::AbstractArray{Tc, 3});
                               algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT) where {Ta, Tb, Tc}
    if size(A, 3) != size(B, 3) || size(A, 3) != size(C, 3)
        throw(DimensionMismatch("Batch sizes must be equal for all inputs"))
    end
    m = size(A, transA == 'N' ? 1 : 2)
    k = size(A, transA == 'N' ? 2 : 1)
    n = size(B, transB == 'N' ? 2 : 1)
    if m != size(C, 1) || n != size(C, 2) || k != size(B, transB == 'N' ? 1 : 2)
        throw(DimensionMismatch("A has dimension $(size(A)), B has dimension $(size(B)), C has dimension $(size(C))"))
    end
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    ldc = max(1, stride(C, 2))

    # A stride of 0 broadcasts a single matrix across the whole batch; the
    # output C must always advance, so it keeps its real stride.
    strideA = size(A, 3) == 1 ? 0 : stride(A, 3)
    strideB = size(B, 3) == 1 ? 0 : stride(B, 3)
    strideC = stride(C, 3)
    batchCount = size(C, 3)

    computeType = gemmExComputeType(eltype(A), eltype(B), eltype(C), m, k, n)
    isnothing(computeType) &&
        throw(ArgumentError("gemmEx does not support $(eltype(C))=$(eltype(A))*$(eltype(B))"))
    computeT = juliaStorageType(eltype(C), computeType)
    if version() >= v"11.0"
        # with CUDA 11, the compute type encodes the math mode.
        cublasGemmStridedBatchedEx(handle(), transA, transB, m, n, k,
                                   Ref{computeT}(alpha), A, eltype(A), lda, strideA,
                                   B, eltype(B), ldb, strideB,
                                   Ref{computeT}(beta), C, eltype(C), ldc, strideC,
                                   batchCount, computeType, algo)
    else
        # cublasGemmStridedBatchedEx is only wrapped for CUDA 11+.
        error("gemmStridedBatchedEx! requires CUDA 11 or higher")
    end
    C
end
996
+
912
997
# create a batch of pointers in device memory from a batch of device arrays
913
998
@inline function unsafe_batch (batch:: Vector{<:CuArray{T}} ) where {T}
914
999
ptrs = pointer .(batch)
@@ -969,6 +1054,7 @@ for (fname, elty) in
969
1054
end
970
1055
end
971
1056
end
1057
+
972
1058
function gemm_batched (transA:: Char , transB:: Char , alpha:: Number ,
973
1059
A:: Vector{<:StridedCuMatrix{T}} , B:: Vector{<:StridedCuMatrix{T}} ) where T
974
1060
C = CuMatrix{T}[similar (B[1 ], (size (A[1 ], transA == ' N' ? 1 : 2 ),size (B[1 ], transB == ' N' ? 2 : 1 ))) for i in 1 : length (A)]
0 commit comments