Skip to content

Commit abd569e

Browse files
lpawela and maleadt authored
Add wrapper for gemmBatchedEx! (#1975)
Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent 315c80e commit abd569e

File tree

2 files changed

+111
-3
lines changed

2 files changed

+111
-3
lines changed

lib/cublas/wrappers.jl

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -887,7 +887,7 @@ function gemmEx!(transA::Char, transB::Char,
887887
k = size(A, transA == 'N' ? 2 : 1)
888888
n = size(B, transB == 'N' ? 2 : 1)
889889
if m != size(C,1) || n != size(C,2) || k != size(B, transB == 'N' ? 1 : 2)
890-
throw(DimensionMismatch(""))
890+
throw(DimensionMismatch("A has dimension $(size(A)), B has dimension $(size(B)) and C has dimension $(size(C))"))
891891
end
892892
lda = max(1,stride(A,2))
893893
ldb = max(1,stride(B,2))
@@ -909,6 +909,91 @@ function gemmEx!(transA::Char, transB::Char,
909909
C
910910
end
911911

912+
"""
    gemmBatchedEx!(transA, transB, alpha, A, B, beta, C; algo=CUBLAS_GEMM_DEFAULT)

Call `cublasGemmBatchedEx` on the batch of device arrays `A`, `B` and `C`, writing
the results into the entries of `C`, and return `C`. The accompanying test checks the
result against `C[i] = (alpha * A[i]) * B[i] + beta * C[i]` for `transA = transB = 'N'`.

# Arguments
- `transA`, `transB`: `'N'` uses the operand's dimensions as stored; any other value
  swaps the roles of its two dimensions in the size checks. The character is passed
  to cuBLAS unchanged.
- `alpha`, `beta`: scalars, converted to the storage type derived from the compute type.
- `A`, `B`, `C`: equally long vectors of strided device vectors/matrices.

# Keywords
- `algo`: algorithm selector passed through to cuBLAS.

# Throws
- `DimensionMismatch` if the batch lengths differ, if an entry's `A`, `B`, `C` shapes
  are incompatible, or if an entry's dimensions differ from those of the first entry.
- `ArgumentError` if `gemmExComputeType` returns `nothing` for the element types.
- `ErrorException` if `version()` is below 11.0.
"""
function gemmBatchedEx!(transA::Char, transB::Char,
                        @nospecialize(alpha::Number),
                        @nospecialize(A::Vector{<:StridedCuVecOrMat}),
                        @nospecialize(B::Vector{<:StridedCuVecOrMat}),
                        @nospecialize(beta::Number),
                        @nospecialize(C::Vector{<:StridedCuVecOrMat});
                        algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT)
    if length(A) != length(B) || length(A) != length(C)
        throw(DimensionMismatch("Lengths of inputs must be the same"))
    end
    # An empty batch has no first entry to read the problem size from; nothing to do.
    isempty(A) && return C

    # A single (m, n, k) is handed to cuBLAS for the whole batch, taken from entry 1.
    m = size(A[1], transA == 'N' ? 1 : 2)
    k = size(A[1], transA == 'N' ? 2 : 1)
    n = size(B[1], transB == 'N' ? 2 : 1)
    for (i, (As,Bs,Cs)) in enumerate(zip(A,B,C))
        mi = size(As, transA == 'N' ? 1 : 2)
        ki = size(As, transA == 'N' ? 2 : 1)
        ni = size(Bs, transB == 'N' ? 2 : 1)
        if mi != size(Cs,1) || ni != size(Cs,2) || ki != size(Bs, transB == 'N' ? 1 : 2)
            throw(DimensionMismatch("Input $i: A has dimension $(size(As)), B has dimension $(size(Bs)), C has dimension $(size(Cs))"))
        end
        # Every entry must share the dimensions passed to cuBLAS, otherwise the
        # library would be told the wrong extents for this entry.
        if mi != m || ki != k || ni != n
            throw(DimensionMismatch("Input $i: A has dimension $(size(As)), B has dimension $(size(Bs)), C has dimension $(size(Cs)); all batch entries must match the dimensions of input 1"))
        end
    end

    # NOTE(review): leading dimensions are read from the first entry only; this
    # presumably requires all entries to share the same column stride - confirm.
    lda = max(1,stride(A[1],2))
    ldb = max(1,stride(B[1],2))
    ldc = max(1,stride(C[1],2))

    computeType = gemmExComputeType(eltype(A[1]), eltype(B[1]), eltype(C[1]), m, k, n)
    isnothing(computeType) &&
        throw(ArgumentError("gemmEx does not support $(eltype(C[1]))=$(eltype(A[1]))*$(eltype(B[1]))"))
    computeT = juliaStorageType(eltype(C[1]), computeType)

    # Reject unsupported versions before allocating the device pointer batches.
    version() >= v"11.0" ||
        error("Not implemented for CUDA versions below 11.0.")

    Aptrs = unsafe_batch(A)
    Bptrs = unsafe_batch(B)
    Cptrs = unsafe_batch(C)
    try
        # with CUDA 11, the compute type encodes the math mode.
        cublasGemmBatchedEx(handle(), transA, transB, m, n, k, Ref{computeT}(alpha), Aptrs, eltype(A[1]), lda, Bptrs,
                            eltype(B[1]), ldb, Ref{computeT}(beta), Cptrs, eltype(C[1]), ldc, length(A), computeType, algo)
    finally
        # Release the pointer batches even when the cuBLAS call throws.
        unsafe_free!(Cptrs)
        unsafe_free!(Bptrs)
        unsafe_free!(Aptrs)
    end

    C
end
956+
957+
"""
    gemmStridedBatchedEx!(transA, transB, alpha, A, B, beta, C; algo=CUBLAS_GEMM_DEFAULT)

Call `cublasGemmStridedBatchedEx` on the 3-dimensional arrays `A`, `B` and `C`, whose
third dimension indexes the batch, writing the result into `C` and returning it. The
accompanying test checks the result against
`C[:, :, i] = (alpha * A[:, :, i]) * B[:, :, i] + beta * C[:, :, i]` for
`transA = transB = 'N'`.

# Arguments
- `transA`, `transB`: `'N'` uses the operand's first two dimensions as stored; any
  other value swaps their roles in the size checks. The character is passed to cuBLAS
  unchanged.
- `alpha`, `beta`: scalars, converted to the storage type derived from the compute type.
- `A`, `B`, `C`: 3-dimensional arrays with equal size along dimension 3.

# Keywords
- `algo`: algorithm selector passed through to cuBLAS.

# Throws
- `DimensionMismatch` if the batch sizes differ or the matrix shapes are incompatible.
- `ArgumentError` if `gemmExComputeType` returns `nothing` for the element types.
- `ErrorException` if `version()` is below 11.0.
"""
function gemmStridedBatchedEx!(transA::Char, transB::Char,
                               @nospecialize(alpha::Number),
                               @nospecialize(A::AbstractArray{Ta, 3}),
                               @nospecialize(B::AbstractArray{Tb, 3}),
                               @nospecialize(beta::Number),
                               @nospecialize(C::AbstractArray{Tc, 3});
                               algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT) where {Ta, Tb, Tc}
    if size(A, 3) != size(B, 3) || size(A, 3) != size(C, 3)
        throw(DimensionMismatch("Batch sizes must be equal for all inputs"))
    end
    m = size(A, transA == 'N' ? 1 : 2)
    k = size(A, transA == 'N' ? 2 : 1)
    n = size(B, transB == 'N' ? 2 : 1)
    if m != size(C,1) || n != size(C,2) || k != size(B, transB == 'N' ? 1 : 2)
        throw(DimensionMismatch("A has dimension $(size(A)), B has dimension $(size(B)), C has dimension $(size(C))"))
    end
    lda = max(1,stride(A,2))
    ldb = max(1,stride(B,2))
    ldc = max(1,stride(C,2))

    # A zero batch stride is used for single-batch inputs. Given the equality check
    # above, that only happens when the batch count itself is 1.
    strideA = size(A, 3) == 1 ? 0 : stride(A, 3)
    strideB = size(B, 3) == 1 ? 0 : stride(B, 3)
    strideC = stride(C, 3)
    batchCount = size(C, 3)

    computeType = gemmExComputeType(eltype(A), eltype(B), eltype(C), m, k, n)
    isnothing(computeType) &&
        throw(ArgumentError("gemmEx does not support $(eltype(C))=$(eltype(A))*$(eltype(B))"))
    computeT = juliaStorageType(eltype(C), computeType)

    # Only the CUDA 11+ calling convention is implemented.
    version() >= v"11.0" ||
        error("Not implemented for CUDA versions below 11.0.")

    # with CUDA 11, the compute type encodes the math mode.
    cublasGemmStridedBatchedEx(handle(), transA, transB, m, n, k, Ref{computeT}(alpha), A, eltype(A), lda, strideA,
                               B, eltype(B), ldb, strideB, Ref{computeT}(beta), C, eltype(C), ldc, strideC,
                               batchCount, computeType, algo)
    C
end
996+
912997
# create a batch of pointers in device memory from a batch of device arrays
913998
@inline function unsafe_batch(batch::Vector{<:CuArray{T}}) where {T}
914999
ptrs = pointer.(batch)
@@ -969,6 +1054,7 @@ for (fname, elty) in
9691054
end
9701055
end
9711056
end
1057+
9721058
function gemm_batched(transA::Char, transB::Char, alpha::Number,
9731059
A::Vector{<:StridedCuMatrix{T}}, B::Vector{<:StridedCuMatrix{T}}) where T
9741060
C = CuMatrix{T}[similar(B[1], (size(A[1], transA == 'N' ? 1 : 2),size(B[1], transB == 'N' ? 2 : 1))) for i in 1:length(A)]

test/libraries/cublas.jl

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,13 +1573,25 @@ end
15731573
@testset "gemm_batched" begin
15741574
bd_C = CUBLAS.gemm_batched('N','N',bd_A,bd_B)
15751575
for i in 1:length(bA)
1576-
bC = bA[i]*bB[i]
1576+
bC[i] = bA[i]*bB[i]
15771577
h_C = Array(bd_C[i])
1578-
@test bC ≈ h_C
1578+
@test bC[i] ≈ h_C
15791579
end
15801580
@test_throws DimensionMismatch CUBLAS.gemm_batched('N','N',alpha,bd_A,bd_bad)
15811581
end
15821582

1583+
@testset "gemmBatchedEx!" begin
    # Reference result on the host: C = (alpha*A)*B + beta*C for every batch entry
    CUBLAS.gemmBatchedEx!('N','N',alpha,bd_A,bd_B,beta,bd_C)
    for i in eachindex(bd_C)
        bC[i] = (alpha*bA[i])*bB[i] + beta*bC[i]
        h_C = Array(bd_C[i])
        # compare host reference and device result up to floating-point tolerance
        @test bC[i] ≈ h_C
    end
    # A batch with mismatched dimensions must be rejected
    @test_throws DimensionMismatch CUBLAS.gemmBatchedEx!('N','N',alpha,bd_A,bd_bad,beta,bd_C)
end
1594+
15831595
nbatch = 10
15841596
bA = rand(elty, m, k, nbatch)
15851597
bB = rand(elty, k, n, nbatch)
@@ -1601,6 +1613,16 @@ end
16011613
@test_throws DimensionMismatch CUBLAS.gemm_strided_batched!('N', 'N', alpha, bd_A, bd_B, beta, bd_bad)
16021614
end
16031615

1616+
@testset "gemmStridedBatchedEx!" begin
    CUBLAS.gemmStridedBatchedEx!('N', 'N', alpha, bd_A, bd_B, beta, bd_C)
    # Host reference, one batch slice at a time
    for i in 1:nbatch
        bC[:, :, i] = (alpha * bA[:, :, i]) * bB[:, :, i] + beta * bC[:, :, i]
    end
    h_C = Array(bd_C)
    # compare host reference and device result up to floating-point tolerance
    @test bC ≈ h_C
    # An output with mismatched dimensions must be rejected
    @test_throws DimensionMismatch CUBLAS.gemmStridedBatchedEx!('N', 'N', alpha, bd_A, bd_B, beta, bd_bad)
end
1625+
16041626
@testset "gemm_strided_batched" begin
16051627
bd_C = CUBLAS.gemm_strided_batched('N', 'N', bd_A, bd_B)
16061628

0 commit comments

Comments
 (0)