@@ -4649,6 +4649,98 @@ end
4649
4649
algo:: cublasGemmAlgo_t ):: cublasStatus_t
4650
4650
end
4651
4651
4652
# Grouped batched single-precision GEMM, 32-bit integer interface.
#
# Thin wrapper: calls `initialize_context()` and then forwards every argument,
# unchanged and in order, to `cublasSgemmGroupedBatched` in libcublas. The foreign
# call is declared to return a `cublasStatus_t`.
#
# Memory spaces follow the cuBLAS documentation for cublas<t>gemmGroupedBatched:
#   * the per-group descriptor arrays (`transa_array`, `transb_array`, `m_array`,
#     `n_array`, `k_array`, `alpha_array`, `lda_array`, `ldb_array`, `beta_array`,
#     `ldc_array`, `group_size`) are read from host memory, hence `Ptr`;
#   * `Aarray`, `Barray` and `Carray` are device-resident arrays of matrix
#     pointers, hence `CuPtr`.
@checked function cublasSgemmGroupedBatched(handle, transa_array, transb_array, m_array,
                                            n_array, k_array, alpha_array, Aarray,
                                            lda_array, Barray, ldb_array, beta_array,
                                            Carray, ldc_array, group_count, group_size)
    initialize_context()
    @gcsafe_ccall libcublas.cublasSgemmGroupedBatched(handle::cublasHandle_t,
                                                      transa_array::Ptr{cublasOperation_t},
                                                      transb_array::Ptr{cublasOperation_t},
                                                      m_array::Ptr{Cint},
                                                      n_array::Ptr{Cint},
                                                      k_array::Ptr{Cint},
                                                      alpha_array::Ptr{Float32},
                                                      Aarray::CuPtr{Ptr{Float32}},
                                                      lda_array::Ptr{Cint},
                                                      Barray::CuPtr{Ptr{Float32}},
                                                      ldb_array::Ptr{Cint},
                                                      beta_array::Ptr{Float32},
                                                      Carray::CuPtr{Ptr{Float32}},
                                                      ldc_array::Ptr{Cint},
                                                      group_count::Cint,
                                                      group_size::Ptr{Cint})::cublasStatus_t
end
4674
+
4675
# Grouped batched single-precision GEMM, 64-bit integer interface.
#
# Thin wrapper: calls `initialize_context()` and then forwards every argument,
# unchanged and in order, to `cublasSgemmGroupedBatched_64` in libcublas. The
# foreign call is declared to return a `cublasStatus_t`.
#
# Memory spaces follow the cuBLAS documentation for cublas<t>gemmGroupedBatched:
#   * the per-group descriptor arrays (`transa_array`, `transb_array`, `m_array`,
#     `n_array`, `k_array`, `alpha_array`, `lda_array`, `ldb_array`, `beta_array`,
#     `ldc_array`, `group_size`) are read from host memory, hence `Ptr`;
#   * `Aarray`, `Barray` and `Carray` are device-resident arrays of matrix
#     pointers, hence `CuPtr`.
# Dimensions, leading dimensions, group sizes and `group_count` are `Int64` here.
@checked function cublasSgemmGroupedBatched_64(handle, transa_array, transb_array, m_array,
                                               n_array, k_array, alpha_array, Aarray,
                                               lda_array, Barray, ldb_array, beta_array,
                                               Carray, ldc_array, group_count, group_size)
    initialize_context()
    @gcsafe_ccall libcublas.cublasSgemmGroupedBatched_64(handle::cublasHandle_t,
                                                         transa_array::Ptr{cublasOperation_t},
                                                         transb_array::Ptr{cublasOperation_t},
                                                         m_array::Ptr{Int64},
                                                         n_array::Ptr{Int64},
                                                         k_array::Ptr{Int64},
                                                         alpha_array::Ptr{Float32},
                                                         Aarray::CuPtr{Ptr{Float32}},
                                                         lda_array::Ptr{Int64},
                                                         Barray::CuPtr{Ptr{Float32}},
                                                         ldb_array::Ptr{Int64},
                                                         beta_array::Ptr{Float32},
                                                         Carray::CuPtr{Ptr{Float32}},
                                                         ldc_array::Ptr{Int64},
                                                         group_count::Int64,
                                                         group_size::Ptr{Int64})::cublasStatus_t
end
4697
+
4698
# Grouped batched double-precision GEMM, 32-bit integer interface.
#
# Thin wrapper: calls `initialize_context()` and then forwards every argument,
# unchanged and in order, to `cublasDgemmGroupedBatched` in libcublas. The foreign
# call is declared to return a `cublasStatus_t`.
#
# Memory spaces follow the cuBLAS documentation for cublas<t>gemmGroupedBatched:
#   * the per-group descriptor arrays (`transa_array`, `transb_array`, `m_array`,
#     `n_array`, `k_array`, `alpha_array`, `lda_array`, `ldb_array`, `beta_array`,
#     `ldc_array`, `group_size`) are read from host memory, hence `Ptr`;
#   * `Aarray`, `Barray` and `Carray` are device-resident arrays of matrix
#     pointers, hence `CuPtr`.
@checked function cublasDgemmGroupedBatched(handle, transa_array, transb_array, m_array,
                                            n_array, k_array, alpha_array, Aarray,
                                            lda_array, Barray, ldb_array, beta_array,
                                            Carray, ldc_array, group_count, group_size)
    initialize_context()
    @gcsafe_ccall libcublas.cublasDgemmGroupedBatched(handle::cublasHandle_t,
                                                      transa_array::Ptr{cublasOperation_t},
                                                      transb_array::Ptr{cublasOperation_t},
                                                      m_array::Ptr{Cint},
                                                      n_array::Ptr{Cint},
                                                      k_array::Ptr{Cint},
                                                      alpha_array::Ptr{Float64},
                                                      Aarray::CuPtr{Ptr{Float64}},
                                                      lda_array::Ptr{Cint},
                                                      Barray::CuPtr{Ptr{Float64}},
                                                      ldb_array::Ptr{Cint},
                                                      beta_array::Ptr{Float64},
                                                      Carray::CuPtr{Ptr{Float64}},
                                                      ldc_array::Ptr{Cint},
                                                      group_count::Cint,
                                                      group_size::Ptr{Cint})::cublasStatus_t
end
4720
+
4721
# Grouped batched double-precision GEMM, 64-bit integer interface.
#
# Thin wrapper: calls `initialize_context()` and then forwards every argument,
# unchanged and in order, to `cublasDgemmGroupedBatched_64` in libcublas. The
# foreign call is declared to return a `cublasStatus_t`.
#
# Memory spaces follow the cuBLAS documentation for cublas<t>gemmGroupedBatched:
#   * the per-group descriptor arrays (`transa_array`, `transb_array`, `m_array`,
#     `n_array`, `k_array`, `alpha_array`, `lda_array`, `ldb_array`, `beta_array`,
#     `ldc_array`, `group_size`) are read from host memory, hence `Ptr`;
#   * `Aarray`, `Barray` and `Carray` are device-resident arrays of matrix
#     pointers, hence `CuPtr`.
# Dimensions, leading dimensions, group sizes and `group_count` are `Int64` here.
@checked function cublasDgemmGroupedBatched_64(handle, transa_array, transb_array, m_array,
                                               n_array, k_array, alpha_array, Aarray,
                                               lda_array, Barray, ldb_array, beta_array,
                                               Carray, ldc_array, group_count, group_size)
    initialize_context()
    @gcsafe_ccall libcublas.cublasDgemmGroupedBatched_64(handle::cublasHandle_t,
                                                         transa_array::Ptr{cublasOperation_t},
                                                         transb_array::Ptr{cublasOperation_t},
                                                         m_array::Ptr{Int64},
                                                         n_array::Ptr{Int64},
                                                         k_array::Ptr{Int64},
                                                         alpha_array::Ptr{Float64},
                                                         Aarray::CuPtr{Ptr{Float64}},
                                                         lda_array::Ptr{Int64},
                                                         Barray::CuPtr{Ptr{Float64}},
                                                         ldb_array::Ptr{Int64},
                                                         beta_array::Ptr{Float64},
                                                         Carray::CuPtr{Ptr{Float64}},
                                                         ldc_array::Ptr{Int64},
                                                         group_count::Int64,
                                                         group_size::Ptr{Int64})::cublasStatus_t
end
4743
+
4652
4744
@checked function cublasSgeam (handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C,
4653
4745
ldc)
4654
4746
initialize_context ()
0 commit comments