Skip to content

Commit 7f725c0

Browse files
authored
[CUBLAS] Interface gemm_grouped_batched (#2310)
Only on CUDA 12.4+
1 parent 3561f73 commit 7f725c0

File tree

4 files changed

+146
-55
lines changed

4 files changed

+146
-55
lines changed

lib/cublas/libcublas.jl

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4657,19 +4657,19 @@ end
46574657
@gcsafe_ccall libcublas.cublasSgemmGroupedBatched(handle::cublasHandle_t,
46584658
transa_array::Ptr{cublasOperation_t},
46594659
transb_array::Ptr{cublasOperation_t},
4660-
m_array::CuPtr{Cint},
4661-
n_array::CuPtr{Cint},
4662-
k_array::CuPtr{Cint},
4663-
alpha_array::CuPtr{Float32},
4660+
m_array::Ptr{Cint},
4661+
n_array::Ptr{Cint},
4662+
k_array::Ptr{Cint},
4663+
alpha_array::Ptr{Float32},
46644664
Aarray::CuPtr{Ptr{Float32}},
4665-
lda_array::CuPtr{Cint},
4665+
lda_array::Ptr{Cint},
46664666
Barray::CuPtr{Ptr{Float32}},
4667-
ldb_array::CuPtr{Cint},
4668-
beta_array::CuPtr{Float32},
4667+
ldb_array::Ptr{Cint},
4668+
beta_array::Ptr{Float32},
46694669
Carray::CuPtr{Ptr{Float32}},
4670-
ldc_array::CuPtr{Cint},
4670+
ldc_array::Ptr{Cint},
46714671
group_count::Cint,
4672-
group_size::CuPtr{Cint})::cublasStatus_t
4672+
group_size::Ptr{Cint})::cublasStatus_t
46734673
end
46744674

46754675
@checked function cublasSgemmGroupedBatched_64(handle, transa_array, transb_array, m_array,
@@ -4680,19 +4680,19 @@ end
46804680
@gcsafe_ccall libcublas.cublasSgemmGroupedBatched_64(handle::cublasHandle_t,
46814681
transa_array::Ptr{cublasOperation_t},
46824682
transb_array::Ptr{cublasOperation_t},
4683-
m_array::CuPtr{Int64},
4684-
n_array::CuPtr{Int64},
4685-
k_array::CuPtr{Int64},
4686-
alpha_array::CuPtr{Float32},
4683+
m_array::Ptr{Int64},
4684+
n_array::Ptr{Int64},
4685+
k_array::Ptr{Int64},
4686+
alpha_array::Ptr{Float32},
46874687
Aarray::CuPtr{Ptr{Float32}},
4688-
lda_array::CuPtr{Int64},
4688+
lda_array::Ptr{Int64},
46894689
Barray::CuPtr{Ptr{Float32}},
4690-
ldb_array::CuPtr{Int64},
4691-
beta_array::CuPtr{Float32},
4690+
ldb_array::Ptr{Int64},
4691+
beta_array::Ptr{Float32},
46924692
Carray::CuPtr{Ptr{Float32}},
4693-
ldc_array::CuPtr{Int64},
4693+
ldc_array::Ptr{Int64},
46944694
group_count::Int64,
4695-
group_size::CuPtr{Int64})::cublasStatus_t
4695+
group_size::Ptr{Int64})::cublasStatus_t
46964696
end
46974697

46984698
@checked function cublasDgemmGroupedBatched(handle, transa_array, transb_array, m_array,
@@ -4703,19 +4703,19 @@ end
47034703
@gcsafe_ccall libcublas.cublasDgemmGroupedBatched(handle::cublasHandle_t,
47044704
transa_array::Ptr{cublasOperation_t},
47054705
transb_array::Ptr{cublasOperation_t},
4706-
m_array::CuPtr{Cint},
4707-
n_array::CuPtr{Cint},
4708-
k_array::CuPtr{Cint},
4709-
alpha_array::CuPtr{Float64},
4706+
m_array::Ptr{Cint},
4707+
n_array::Ptr{Cint},
4708+
k_array::Ptr{Cint},
4709+
alpha_array::Ptr{Float64},
47104710
Aarray::CuPtr{Ptr{Float64}},
4711-
lda_array::CuPtr{Cint},
4711+
lda_array::Ptr{Cint},
47124712
Barray::CuPtr{Ptr{Float64}},
4713-
ldb_array::CuPtr{Cint},
4714-
beta_array::CuPtr{Float64},
4713+
ldb_array::Ptr{Cint},
4714+
beta_array::Ptr{Float64},
47154715
Carray::CuPtr{Ptr{Float64}},
4716-
ldc_array::CuPtr{Cint},
4716+
ldc_array::Ptr{Cint},
47174717
group_count::Cint,
4718-
group_size::CuPtr{Cint})::cublasStatus_t
4718+
group_size::Ptr{Cint})::cublasStatus_t
47194719
end
47204720

47214721
@checked function cublasDgemmGroupedBatched_64(handle, transa_array, transb_array, m_array,
@@ -4726,19 +4726,19 @@ end
47264726
@gcsafe_ccall libcublas.cublasDgemmGroupedBatched_64(handle::cublasHandle_t,
47274727
transa_array::Ptr{cublasOperation_t},
47284728
transb_array::Ptr{cublasOperation_t},
4729-
m_array::CuPtr{Int64},
4730-
n_array::CuPtr{Int64},
4731-
k_array::CuPtr{Int64},
4732-
alpha_array::CuPtr{Float64},
4729+
m_array::Ptr{Int64},
4730+
n_array::Ptr{Int64},
4731+
k_array::Ptr{Int64},
4732+
alpha_array::Ptr{Float64},
47334733
Aarray::CuPtr{Ptr{Float64}},
4734-
lda_array::CuPtr{Int64},
4734+
lda_array::Ptr{Int64},
47354735
Barray::CuPtr{Ptr{Float64}},
4736-
ldb_array::CuPtr{Int64},
4737-
beta_array::CuPtr{Float64},
4736+
ldb_array::Ptr{Int64},
4737+
beta_array::Ptr{Float64},
47384738
Carray::CuPtr{Ptr{Float64}},
4739-
ldc_array::CuPtr{Int64},
4739+
ldc_array::Ptr{Int64},
47404740
group_count::Int64,
4741-
group_size::CuPtr{Int64})::cublasStatus_t
4741+
group_size::Ptr{Int64})::cublasStatus_t
47424742
end
47434743

47444744
@checked function cublasSgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C,

lib/cublas/wrappers.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,6 +1211,73 @@ end
12111211
return CuArray(ptrs)
12121212
end
12131213

1214+
## (GE) general matrix-matrix multiplication grouped batched
1215+
for (fname, fname_64, elty) in ((:cublasSgemmGroupedBatched, :cublasSgemmGroupedBatched_64, :Float32),
                                (:cublasDgemmGroupedBatched, :cublasDgemmGroupedBatched_64, :Float64))
    @eval begin
        # gemm_grouped_batched!(transA, transB, alpha, A, B, beta, C)
        #
        # In-place grouped batched GEMM: for every index `i`, update
        # `C[i] = alpha[i] * op(A[i]) * op(B[i]) + beta[i] * C[i]`, where `op` is
        # selected by `transA[i]` / `transB[i]`. Every matrix forms its own group
        # (all group sizes are 1), so the matrices may have different shapes.
        #
        # Arguments
        #   transA, transB : one operation character per group; 'N' means "use as
        #                    stored", anything else swaps the two dimensions.
        #   alpha, beta    : one host scalar per group.
        #   A, B, C        : device matrices, one per group; `C` is overwritten.
        #
        # Returns `C`.
        #
        # Throws `DimensionMismatch` when the argument vectors do not all have one
        # entry per group, or when the shapes within a group are incompatible.
        function gemm_grouped_batched!(transA::Vector{Char},
                                       transB::Vector{Char},
                                       alpha::Vector{$elty},
                                       A::Vector{<:StridedCuMatrix{$elty}},
                                       B::Vector{<:StridedCuMatrix{$elty}},
                                       beta::Vector{$elty},
                                       C::Vector{<:StridedCuMatrix{$elty}})
            if length(A) != length(B) || length(A) != length(C)
                throw(DimensionMismatch("A, B and C must contain the same number of matrices"))
            end
            group_count = length(A)

            # The library reads `group_count` entries from each of these host arrays,
            # so a short vector must be rejected before the call.
            if length(transA) != group_count || length(transB) != group_count ||
               length(alpha) != group_count || length(beta) != group_count
                throw(DimensionMismatch(string(
                    "transA, transB, alpha and beta must each contain one entry per group (",
                    group_count, ")")))
            end

            # Problem sizes are gathered once, while validating each group. They are
            # stored as Int64 explicitly because the `_64` entry point is declared
            # with `Ptr{Int64}` arguments, independent of the platform's `Int`.
            m = Vector{Int64}(undef, group_count)
            n = Vector{Int64}(undef, group_count)
            k = Vector{Int64}(undef, group_count)
            for i in 1:group_count
                m[i] = size(A[i], transA[i] == 'N' ? 1 : 2)
                k[i] = size(A[i], transA[i] == 'N' ? 2 : 1)
                n[i] = size(B[i], transB[i] == 'N' ? 2 : 1)
                kb = size(B[i], transB[i] == 'N' ? 1 : 2)
                if m[i] != size(C[i], 1) || n[i] != size(C[i], 2) || k[i] != kb
                    throw(DimensionMismatch(string(
                        "group ", i, ": op(A) has size (", m[i], ", ", k[i],
                        "), op(B) has size (", kb, ", ", n[i],
                        "), C has size ", size(C[i]))))
                end
            end

            transa = convert.(cublasOperation_t, transA)
            transb = convert.(cublasOperation_t, transB)
            lda = Int64[max(1, stride(A[i], 2)) for i in 1:group_count]
            ldb = Int64[max(1, stride(B[i], 2)) for i in 1:group_count]
            ldc = Int64[max(1, stride(C[i], 2)) for i in 1:group_count]
            group_size = ones(Int64, group_count)

            Aptrs = unsafe_batch(A)
            Bptrs = unsafe_batch(B)
            Cptrs = unsafe_batch(C)
            try
                # NOTE(review): the tests for this routine are only run for
                # CUBLAS >= 12.4.2 — confirm whether the 32-bit fallback below is
                # reachable in practice.
                if CUBLAS.version() >= v"12.0"
                    $fname_64(handle(), transa, transb, m, n, k, alpha, Aptrs, lda,
                              Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
                else
                    # The 32-bit entry point is declared with `Ptr{Cint}` arrays, so
                    # narrow element-wise instead of handing it Int64 buffers.
                    $fname(handle(), transa, transb, Cint.(m), Cint.(n), Cint.(k), alpha,
                           Aptrs, Cint.(lda), Bptrs, Cint.(ldb), beta, Cptrs, Cint.(ldc),
                           group_count, Cint.(group_size))
                end
            finally
                # Release the temporary device pointer arrays even if the call throws.
                unsafe_free!(Cptrs)
                unsafe_free!(Bptrs)
                unsafe_free!(Aptrs)
            end

            return C
        end
    end
end
1268+
1269+
# gemm_grouped_batched(transA, transB, alpha, A, B)
#
# Out-of-place grouped batched GEMM: allocates one output matrix per group, sized
# from op(A[i]) and op(B[i]), and forwards to `gemm_grouped_batched!` with every
# `beta` set to zero. Returns whatever the in-place variant returns.
function gemm_grouped_batched(transA::Vector{Char}, transB::Vector{Char}, alpha::Vector{T},
                              A::Vector{<:StridedCuMatrix{T}}, B::Vector{<:StridedCuMatrix{T}}) where T
    beta = zeros(T, length(transA))

    # One output per group: rows come from op(A[i]), columns from op(B[i]).
    C = CuMatrix{T}[]
    for idx in 1:length(A)
        nrows = size(A[idx], transA[idx] == 'N' ? 1 : 2)
        ncols = size(B[idx], transB[idx] == 'N' ? 2 : 1)
        push!(C, similar(B[idx], (nrows, ncols)))
    end

    return gemm_grouped_batched!(transA, transB, alpha, A, B, beta, C)
end
1275+
# gemm_grouped_batched(transA, transB, A, B)
#
# Convenience method: same as the five-argument method with every `alpha` set to
# one (one entry per element of `transA`).
function gemm_grouped_batched(transA::Vector{Char}, transB::Vector{Char},
                              A::Vector{<:StridedCuMatrix{T}}, B::Vector{<:StridedCuMatrix{T}}) where T
    alpha = ones(T, length(transA))
    return gemm_grouped_batched(transA, transB, alpha, A, B)
end
1280+
12141281
## (GE) general matrix-matrix multiplication batched
12151282
for (fname, fname_64, elty) in ((:cublasDgemmBatched, :cublasDgemmBatched_64, :Float64),
12161283
(:cublasSgemmBatched, :cublasSgemmBatched_64, :Float32),

res/wrap/cublas.toml

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1773,29 +1773,11 @@ needs_context = false
17731773
2 = "CuPtr{Cvoid}"
17741774

17751775
[api.cublasSgemmGroupedBatched.argtypes]
1776-
4 = "CuPtr{Cint}"
1777-
5 = "CuPtr{Cint}"
1778-
6 = "CuPtr{Cint}"
1779-
7 = "CuPtr{Float32}"
17801776
8 = "CuPtr{Ptr{Float32}}"
1781-
9 = "CuPtr{Cint}"
17821777
10 = "CuPtr{Ptr{Float32}}"
1783-
11 = "CuPtr{Cint}"
1784-
12 = "CuPtr{Float32}"
17851778
13 = "CuPtr{Ptr{Float32}}"
1786-
14 = "CuPtr{Cint}"
1787-
16 = "CuPtr{Cint}"
17881779

17891780
[api.cublasDgemmGroupedBatched.argtypes]
1790-
4 = "CuPtr{Cint}"
1791-
5 = "CuPtr{Cint}"
1792-
6 = "CuPtr{Cint}"
1793-
7 = "CuPtr{Float64}"
17941781
8 = "CuPtr{Ptr{Float64}}"
1795-
9 = "CuPtr{Cint}"
17961782
10 = "CuPtr{Ptr{Float64}}"
1797-
11 = "CuPtr{Cint}"
1798-
12 = "CuPtr{Float64}"
17991783
13 = "CuPtr{Ptr{Float64}}"
1800-
14 = "CuPtr{Cint}"
1801-
16 = "CuPtr{Cint}"

test/libraries/cublas.jl

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1597,7 +1597,7 @@ end
15971597
end
15981598
end
15991599

1600-
@testset for elty in [Float16, Float32, Float64, ComplexF32, ComplexF64]
1600+
@testset "elty = $elty" for elty in [Float16, Float32, Float64, ComplexF32, ComplexF64]
16011601
elty == Float16 && capability(device()) < v"5.3" && continue
16021602

16031603
alpha = rand(elty)
@@ -1711,6 +1711,48 @@ end
17111711
end
17121712
end
17131713

1714+
# Grouped batched GEMM tests; only run when the installed CUBLAS is at least 12.4.2.
if CUDA.CUBLAS.version() >= v"12.4.2"
    @testset "elty = $elty" for elty in [Float32, Float64]
        num_groups = 10

        transA = fill('N', num_groups)
        transB = fill('N', num_groups)
        alpha = rand(elty, num_groups)
        beta = rand(elty, num_groups)

        # generate host matrices whose sizes differ from group to group
        bA = [rand(elty, 3*i, 2*i) for i in 1:num_groups]
        bB = [rand(elty, 2*i, 5*i) for i in 1:num_groups]
        bC = [rand(elty, 3*i, 5*i) for i in 1:num_groups]

        # move to device
        bd_A = CuArray{elty, 2}[CuArray(a) for a in bA]
        bd_B = CuArray{elty, 2}[CuArray(b) for b in bB]
        bd_C = CuArray{elty, 2}[CuArray(c) for c in bC]

        @testset "gemm_grouped_batched!" begin
            # C = (alpha*A)*B + beta*C
            CUBLAS.gemm_grouped_batched!(transA, transB, alpha, bd_A, bd_B, beta, bd_C)
            for i in 1:length(bd_C)
                expected = alpha[i] * bA[i] * bB[i] + beta[i] * bC[i]
                h_C = Array(bd_C[i])
                @test expected ≈ h_C
            end
        end

        @testset "gemm_grouped_batched" begin
            # C = A*B, with the outputs allocated by the wrapper
            bd_out = CUBLAS.gemm_grouped_batched(transA, transB, bd_A, bd_B)
            for i in 1:length(bd_out)
                expected = bA[i] * bB[i]
                h_C = Array(bd_out[i])
                @test expected ≈ h_C
            end
        end
    end
end
1755+
17141756
@testset "mixed-precision matmul" begin
17151757
m,k,n = 4,4,4
17161758
cudaTypes = (Float16, Complex{Float16}, BFloat16, Complex{BFloat16}, Float32, Complex{Float32},

0 commit comments

Comments
 (0)