@@ -1214,6 +1214,63 @@ end
1214
1214
# # (GE) general matrix-matrix multiplication grouped batched
1215
1215
for (fname, fname_64, elty) in ((:cublasSgemmGroupedBatched , :cublasSgemmGroupedBatched_64 , :Float32 ),
1216
1216
(:cublasDgemmGroupedBatched , :cublasDgemmGroupedBatched_64 , :Float64 ))
1217
+ @eval begin
1218
+ function gemm_grouped_batched! (transA:: Vector{Char} ,
1219
+ transB:: Vector{Char} ,
1220
+ alpha:: Vector{$elty} ,
1221
+ A:: Vector{<:Vector{<:StridedCuMatrix{$elty}}} ,
1222
+ B:: Vector{<:Vector{<:StridedCuMatrix{$elty}}} ,
1223
+ beta:: Vector{$elty} ,
1224
+ C:: Vector{<:Vector{<:StridedCuMatrix{$elty}}} )
1225
+
1226
+ if length (A) != length (B) || length (A) != length (C)
1227
+ throw (DimensionMismatch (" A, B and C must contain the same number of groups" ))
1228
+ end
1229
+ group_count = length (A)
1230
+ for i= 1 : group_count
1231
+ if length (A[i]) != length (B[i]) || length (A[i]) != length (C[i])
1232
+ throw (DimensionMismatch (" A, B and C must contain the same number of matrices" ))
1233
+ end
1234
+ end
1235
+ group_size = length .(A)
1236
+
1237
+ for i = 1 : group_count
1238
+ m = size (A[i][1 ], transA[i] == ' N' ? 1 : 2 )
1239
+ k = size (A[i][1 ], transA[i] == ' N' ? 2 : 1 )
1240
+ n = size (B[i][1 ], transB[i] == ' N' ? 2 : 1 )
1241
+ if m != size (C[i][1 ],1 ) || n != size (C[i][1 ],2 ) || k != size (B[i][1 ], transB[i] == ' N' ? 1 : 2 )
1242
+ throw (DimensionMismatch (" " ))
1243
+ end
1244
+ end
1245
+
1246
+ transa = convert .(cublasOperation_t, transA)
1247
+ transb = convert .(cublasOperation_t, transB)
1248
+ m = [size (A[i][1 ], transA[i] == ' N' ? 1 : 2 ) for i = 1 : group_count]
1249
+ k = [size (A[i][1 ], transA[i] == ' N' ? 2 : 1 ) for i = 1 : group_count]
1250
+ n = [size (B[i][1 ], transB[i] == ' N' ? 2 : 1 ) for i = 1 : group_count]
1251
+ lda = [max (1 ,stride (A[i][1 ],2 )) for i = 1 : group_count]
1252
+ ldb = [max (1 ,stride (B[i][1 ],2 )) for i = 1 : group_count]
1253
+ ldc = [max (1 ,stride (C[i][1 ],2 )) for i = 1 : group_count]
1254
+ Aptrs = unsafe_batch (reduce (vcat, A))
1255
+ Bptrs = unsafe_batch (reduce (vcat, B))
1256
+ Cptrs = unsafe_batch (reduce (vcat, C))
1257
+
1258
+ if CUBLAS. version () >= v " 12.0"
1259
+ $ fname_64 (handle (), transa, transb, m, n, k, alpha, Aptrs, lda,
1260
+ Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
1261
+ else
1262
+ $ fname (handle (), transa, transb, m, n, k, alpha, Aptrs, lda,
1263
+ Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
1264
+ end
1265
+ unsafe_free! (Cptrs)
1266
+ unsafe_free! (Bptrs)
1267
+ unsafe_free! (Aptrs)
1268
+
1269
+ C
1270
+ end
1271
+ end
1272
+
1273
+ # Group size hardcoded to one
1217
1274
@eval begin
1218
1275
function gemm_grouped_batched! (transA:: Vector{Char} ,
1219
1276
transB:: Vector{Char} ,
@@ -1260,24 +1317,40 @@ for (fname, fname_64, elty) in ((:cublasSgemmGroupedBatched, :cublasSgemmGrouped
1260
1317
unsafe_free! (Cptrs)
1261
1318
unsafe_free! (Bptrs)
1262
1319
unsafe_free! (Aptrs)
1263
-
1264
1320
C
1265
1321
end
1266
1322
end
1267
1323
end
1268
1324
1269
1325
function gemm_grouped_batched (transA:: Vector{Char} , transB:: Vector{Char} , alpha:: Vector{T} ,
1270
- A:: Vector{<:StridedCuMatrix{T}} , B:: Vector{<:StridedCuMatrix{T}} ) where T
1271
- beta = [zero (T) for i = 1 : length (transA)]
1272
- C = CuMatrix{T}[similar (B[i], (size (A[i], transA[i] == ' N' ? 1 : 2 ), size (B[i], transB[i] == ' N' ? 2 : 1 ))) for i in 1 : length (A)]
1326
+ A:: Vector{<:Vector{<:StridedCuMatrix{T}}} , B:: Vector{<:Vector{<:StridedCuMatrix{T}}} ) where T
1327
+ num_groups = length (A)
1328
+ group_sizes = length .(A)
1329
+ beta = [zero (T) for i = 1 : num_groups]
1330
+ C = [[similar (B[i][j], (size (A[i][j], transA[i] == ' N' ? 1 : 2 ), size (B[i][j], transB[i] == ' N' ? 2 : 1 ))) for j in 1 : group_sizes[i]] for i in 1 : num_groups]
1273
1331
gemm_grouped_batched! (transA, transB, alpha, A, B, beta, C)
1274
1332
end
1333
+
1275
1334
function gemm_grouped_batched (transA:: Vector{Char} , transB:: Vector{Char} ,
1276
- A:: Vector{<:StridedCuMatrix{T}} , B:: Vector{<:StridedCuMatrix{T}} ) where T
1335
+ A:: Vector{<:Vector{<: StridedCuMatrix{T}}} , B:: Vector{<:Vector{<: StridedCuMatrix{T} }} ) where T
1277
1336
alpha = [one (T) for i = 1 : length (transA)]
1278
1337
gemm_grouped_batched (transA, transB, alpha, A, B)
1279
1338
end
1280
1339
1340
+ # Group size hardcoded to one
1341
+ function gemm_grouped_batched (transA:: Vector{Char} , transB:: Vector{Char} , alpha:: Vector{T} ,
1342
+ A:: Vector{<:StridedCuMatrix{T}} , B:: Vector{<:StridedCuMatrix{T}} ) where T
1343
+ beta = [zero (T) for i = 1 : length (transA)]
1344
+ C = CuMatrix{T}[similar (B[i], (size (A[i], transA[i] == ' N' ? 1 : 2 ), size (B[i], transB[i] == ' N' ? 2 : 1 ))) for i in 1 : length (A)]
1345
+ gemm_grouped_batched! (transA, transB, alpha, A, B, beta, C)
1346
+ end
1347
+
1348
+ function gemm_grouped_batched (transA:: Vector{Char} , transB:: Vector{Char} ,
1349
+ A:: Vector{<:StridedCuMatrix{T}} , B:: Vector{<:StridedCuMatrix{T}} ) where T
1350
+ alpha = [one (T) for i = 1 : length (transA)]
1351
+ gemm_grouped_batched (transA, transB, alpha, A, B)
1352
+ end
1353
+
1281
1354
# # (GE) general matrix-matrix multiplication batched
1282
1355
for (fname, fname_64, elty) in ((:cublasDgemmBatched , :cublasDgemmBatched_64 , :Float64 ),
1283
1356
(:cublasSgemmBatched , :cublasSgemmBatched_64 , :Float32 ),
0 commit comments