Skip to content

Commit beccab1

Browse files
authored
Add support for arbitrary group sizes in gemm_grouped_batched! (#2334)
1 parent 1bdbb86 commit beccab1

File tree

2 files changed

+116
-5
lines changed

2 files changed

+116
-5
lines changed

lib/cublas/wrappers.jl

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,6 +1214,63 @@ end
12141214
## (GE) general matrix-matrix multiplication grouped batched
12151215
for (fname, fname_64, elty) in ((:cublasSgemmGroupedBatched, :cublasSgemmGroupedBatched_64, :Float32),
12161216
(:cublasDgemmGroupedBatched, :cublasDgemmGroupedBatched_64, :Float64))
1217+
@eval begin
1218+
function gemm_grouped_batched!(transA::Vector{Char},
1219+
transB::Vector{Char},
1220+
alpha::Vector{$elty},
1221+
A::Vector{<:Vector{<:StridedCuMatrix{$elty}}},
1222+
B::Vector{<:Vector{<:StridedCuMatrix{$elty}}},
1223+
beta::Vector{$elty},
1224+
C::Vector{<:Vector{<:StridedCuMatrix{$elty}}})
1225+
1226+
if length(A) != length(B) || length(A) != length(C)
1227+
throw(DimensionMismatch("A, B and C must contain the same number of groups"))
1228+
end
1229+
group_count = length(A)
1230+
for i=1:group_count
1231+
if length(A[i]) != length(B[i]) || length(A[i]) != length(C[i])
1232+
throw(DimensionMismatch("A, B and C must contain the same number of matrices"))
1233+
end
1234+
end
1235+
group_size = length.(A)
1236+
1237+
for i = 1:group_count
1238+
m = size(A[i][1], transA[i] == 'N' ? 1 : 2)
1239+
k = size(A[i][1], transA[i] == 'N' ? 2 : 1)
1240+
n = size(B[i][1], transB[i] == 'N' ? 2 : 1)
1241+
if m != size(C[i][1],1) || n != size(C[i][1],2) || k != size(B[i][1], transB[i] == 'N' ? 1 : 2)
1242+
throw(DimensionMismatch(""))
1243+
end
1244+
end
1245+
1246+
transa = convert.(cublasOperation_t, transA)
1247+
transb = convert.(cublasOperation_t, transB)
1248+
m = [size(A[i][1], transA[i] == 'N' ? 1 : 2) for i = 1 : group_count]
1249+
k = [size(A[i][1], transA[i] == 'N' ? 2 : 1) for i = 1 : group_count]
1250+
n = [size(B[i][1], transB[i] == 'N' ? 2 : 1) for i = 1 : group_count]
1251+
lda = [max(1,stride(A[i][1],2)) for i = 1 : group_count]
1252+
ldb = [max(1,stride(B[i][1],2)) for i = 1 : group_count]
1253+
ldc = [max(1,stride(C[i][1],2)) for i = 1 : group_count]
1254+
Aptrs = unsafe_batch(reduce(vcat, A))
1255+
Bptrs = unsafe_batch(reduce(vcat, B))
1256+
Cptrs = unsafe_batch(reduce(vcat, C))
1257+
1258+
if CUBLAS.version() >= v"12.0"
1259+
$fname_64(handle(), transa, transb, m, n, k, alpha, Aptrs, lda,
1260+
Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
1261+
else
1262+
$fname(handle(), transa, transb, m, n, k, alpha, Aptrs, lda,
1263+
Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
1264+
end
1265+
unsafe_free!(Cptrs)
1266+
unsafe_free!(Bptrs)
1267+
unsafe_free!(Aptrs)
1268+
1269+
C
1270+
end
1271+
end
1272+
1273+
# Group size hardcoded to one
12171274
@eval begin
12181275
function gemm_grouped_batched!(transA::Vector{Char},
12191276
transB::Vector{Char},
@@ -1260,24 +1317,40 @@ for (fname, fname_64, elty) in ((:cublasSgemmGroupedBatched, :cublasSgemmGrouped
12601317
unsafe_free!(Cptrs)
12611318
unsafe_free!(Bptrs)
12621319
unsafe_free!(Aptrs)
1263-
12641320
C
12651321
end
12661322
end
12671323
end
12681324

12691325
function gemm_grouped_batched(transA::Vector{Char}, transB::Vector{Char}, alpha::Vector{T},
1270-
A::Vector{<:StridedCuMatrix{T}}, B::Vector{<:StridedCuMatrix{T}}) where T
1271-
beta = [zero(T) for i = 1:length(transA)]
1272-
C = CuMatrix{T}[similar(B[i], (size(A[i], transA[i] == 'N' ? 1 : 2), size(B[i], transB[i] == 'N' ? 2 : 1))) for i in 1:length(A)]
1326+
A::Vector{<:Vector{<:StridedCuMatrix{T}}}, B::Vector{<:Vector{<:StridedCuMatrix{T}}}) where T
1327+
num_groups = length(A)
1328+
group_sizes = length.(A)
1329+
beta = [zero(T) for i = 1:num_groups]
1330+
C = [[similar(B[i][j], (size(A[i][j], transA[i] == 'N' ? 1 : 2), size(B[i][j], transB[i] == 'N' ? 2 : 1))) for j in 1:group_sizes[i]] for i in 1:num_groups]
12731331
gemm_grouped_batched!(transA, transB, alpha, A, B, beta, C)
12741332
end
1333+
12751334
function gemm_grouped_batched(transA::Vector{Char}, transB::Vector{Char},
1276-
A::Vector{<:StridedCuMatrix{T}}, B::Vector{<:StridedCuMatrix{T}}) where T
1335+
A::Vector{<:Vector{<:StridedCuMatrix{T}}}, B::Vector{<:Vector{<:StridedCuMatrix{T}}}) where T
12771336
alpha = [one(T) for i = 1:length(transA)]
12781337
gemm_grouped_batched(transA, transB, alpha, A, B)
12791338
end
12801339

1340+
# Group size hardcoded to one
1341+
function gemm_grouped_batched(transA::Vector{Char}, transB::Vector{Char}, alpha::Vector{T},
1342+
A::Vector{<:StridedCuMatrix{T}}, B::Vector{<:StridedCuMatrix{T}}) where T
1343+
beta = [zero(T) for i = 1:length(transA)]
1344+
C = CuMatrix{T}[similar(B[i], (size(A[i], transA[i] == 'N' ? 1 : 2), size(B[i], transB[i] == 'N' ? 2 : 1))) for i in 1:length(A)]
1345+
gemm_grouped_batched!(transA, transB, alpha, A, B, beta, C)
1346+
end
1347+
1348+
function gemm_grouped_batched(transA::Vector{Char}, transB::Vector{Char},
1349+
A::Vector{<:StridedCuMatrix{T}}, B::Vector{<:StridedCuMatrix{T}}) where T
1350+
alpha = [one(T) for i = 1:length(transA)]
1351+
gemm_grouped_batched(transA, transB, alpha, A, B)
1352+
end
1353+
12811354
## (GE) general matrix-matrix multiplication batched
12821355
for (fname, fname_64, elty) in ((:cublasDgemmBatched, :cublasDgemmBatched_64, :Float64),
12831356
(:cublasSgemmBatched, :cublasSgemmBatched_64, :Float32),

test/libraries/cublas.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1711,6 +1711,44 @@ end
17111711
end
17121712
end
17131713

1714+
if CUDA.CUBLAS.version() >= v"12.4.2"
1715+
@testset "elty = $elty" for elty in [Float32, Float64]
1716+
num_groups = 10
1717+
group_sizes = collect(1:num_groups)
1718+
transA = ['N' for i in 1:num_groups]
1719+
transB = ['N' for i in 1:num_groups]
1720+
alpha = rand(elty, num_groups)
1721+
beta = rand(elty, num_groups)
1722+
# generate matrices
1723+
bA = [[rand(elty,3*i,2*i) for j in 1:group_sizes[i]] for i in 1:num_groups]
1724+
bB = [[rand(elty,2*i,5*i) for j in 1:group_sizes[i]] for i in 1:num_groups]
1725+
bC = [[rand(elty,3*i,5*i) for j in 1:group_sizes[i]] for i in 1:num_groups]
1726+
# move to device
1727+
bd_A = [[CuArray(bA[i][j]) for j in 1:group_sizes[i]] for i in 1:num_groups]
1728+
bd_B = [[CuArray(bB[i][j]) for j in 1:group_sizes[i]] for i in 1:num_groups]
1729+
bd_C = [[CuArray(bC[i][j]) for j in 1:group_sizes[i]] for i in 1:num_groups]
1730+
@testset "gemm_grouped_batched!" begin
1731+
# C = (alpha*A)*B + beta*C
1732+
CUBLAS.gemm_grouped_batched!(transA,transB,alpha,bd_A,bd_B,beta,bd_C)
1733+
for i in 1:num_groups, j in 1:group_sizes[i]
1734+
bC[i][j] = alpha[i] * bA[i][j] * bB[i][j] + beta[i] * bC[i][j]
1735+
h_C = Array(bd_C[i][j])
1736+
@test bC[i][j] h_C
1737+
end
1738+
end
1739+
1740+
@testset "gemm_grouped_batched" begin
1741+
bd_C = CUBLAS.gemm_grouped_batched(transA,transB,bd_A,bd_B)
1742+
for i in 1:num_groups, j in 1:group_sizes[i]
1743+
bC[i][j] = bA[i][j] * bB[i][j]
1744+
h_C = Array(bd_C[i][j])
1745+
@test bC[i][j] h_C
1746+
end
1747+
end
1748+
end
1749+
end
1750+
1751+
# Group size hardcoded to one
17141752
if CUDA.CUBLAS.version() >= v"12.4.2"
17151753
@testset "elty = $elty" for elty in [Float32, Float64]
17161754

0 commit comments

Comments
 (0)