Commit 990f931

slaren and ggerganov authored
test-backend-ops : add performance eval mode + improve CUDA repeat and binary broadcast ops performance (#636)
* ggml-cuda : implement repeat with bin_bcast
* ggml-cuda : change supports_op for mul_mat to match compute_forward
* test-backend-ops : add performance eval mode
* improve formatting
* add sd test cases
* fix test case
* ggml-cuda : bin_bcast: better block sizes, two elements per thread
* metal : add dim3 broadcast support for mul mat
* cleanup
* typo
* metal : enable mul mat-vec for dim2 > 1
* metal : mul mat-vec support dim3 broadcasts ggml-ci
* ggml-cuda : fix bin_bcast for ne0=1 ggml-ci
* ggml-cuda : limit block size z dim to 64
* test-backend-ops : add test cases
* test-backend-ops : add warmup run, print test type before trying to compute
* ggml-cuda : bin_bcast: collapse dimensions when possible, add fallback kernel for large tensors ggml-ci
* test-backend-ops : avoid division by zero

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 3f66942 commit 990f931

File tree

4 files changed: +761 −378 lines


src/ggml-cuda.cu

Lines changed: 188 additions & 90 deletions
```diff
@@ -434,7 +434,6 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
-#define CUDA_ADDMUL_BLOCK_SIZE 256
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_RELU_BLOCK_SIZE 256
```
```diff
@@ -501,6 +500,10 @@ static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
+static __device__ __forceinline__ float op_repeat(const float a, const float b) {
+    return b;
+}
+
 static __device__ __forceinline__ float op_add(const float a, const float b) {
     return a + b;
 }
```
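`op_repeat` deliberately ignores its first argument and returns the broadcast operand unchanged, which is what lets the shared broadcast kernels below implement repeat as a plain copy (with `src0 == nullptr`, guarded inside the kernels). A minimal host-side sketch of the same function-pointer-as-template-parameter pattern, for illustration only; the `apply` helper and its names are hypothetical, not part of the commit:

```cpp
#include <cstdio>

static float op_repeat(const float a, const float b) { return b; }
static float op_add   (const float a, const float b) { return a + b; }

// same pattern as k_bin_bcast: the binary op is a compile-time template
// parameter, so each instantiation inlines the op with no call overhead
template<float (*bin_op)(const float, const float)>
static void apply(const float * a, const float * b, float * dst, int n) {
    for (int i = 0; i < n; i++) {
        dst[i] = bin_op(a ? a[i] : 0.0f, b[i]); // a may be null for repeat
    }
}

int main() {
    const float b[3] = {1.0f, 2.0f, 3.0f};
    float dst[3];
    apply<op_repeat>(nullptr, b, dst, 3); // pure broadcast copy: dst == b
    apply<op_add>(b, b, dst, 3);          // dst == 2*b
    printf("%g %g %g\n", dst[0], dst[1], dst[2]);
}
```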
```diff
@@ -515,29 +518,69 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
 
 template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
 static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
-        int ne0,/* int ne1, int ne2, */int ne3,
+        int ne0, int ne1, int ne2, int ne3,
         int ne10, int ne11, int ne12, int ne13,
         /*int s0, */ int s1, int s2, int s3,
         /*int s10,*/ int s11, int s12, int s13) {
-    const int i0 = blockDim.x*blockIdx.x + threadIdx.x;
-    const int i1 = blockIdx.y;
-    const int i2 = blockIdx.z / ne3;
-    const int i3 = blockIdx.z % ne3;
+    const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
+    const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
+    const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
+    const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;
 
-    if (i0 >= ne0) {
+    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+        return;
+    }
+
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
+
+    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst = i_src0;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
+
+    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
+        const int i10 = i0 % ne10;
+        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+    }
+}
+
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0, int ne1, int ne2, int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1, int s2, int s3,
+        /*int s10,*/ int s11, int s12, int s13) {
+
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    const int i3 = i/(ne2*ne1*ne0);
+    const int i2 = (i/(ne1*ne0)) % ne2;
+    const int i1 = (i/ne0) % ne1;
+    const int i0 = i % ne0;
+
+    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
         return;
     }
 
-    const int i10 = i0 % ne10;
     const int i11 = i1 % ne11;
     const int i12 = i2 % ne12;
     const int i13 = i3 % ne13;
 
-    const size_t i_dst = i3*s3 + i2*s2 + i1*s1 + i0;
-    const size_t i_src0 = i_dst;
-    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11 + i10;
+    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst = i_src0;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
 
-    dst[i_dst] = (dst_t)bin_op((float)src0[i_src0], (float)src1[i_src1]);
+    const int i10 = i0 % ne10;
+    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
 }
 
 static __global__ void gelu_f32(const float * x, float * dst, const int k) {
```
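Both kernels implement the same modulo-broadcast rule: every `dst` element reads `src1` with each index reduced modulo the corresponding `src1` extent. `k_bin_bcast` walks dim 0 with a grid-stride loop so one thread can cover several elements, while `k_bin_bcast_unravel` flattens all four dimensions into a single 1D index for grids that would exceed CUDA's z-dimension limit. The following CPU reference is a sketch of the semantics only, for contiguous f32 tensors; the `bin_bcast_ref` helper is hypothetical, not from the commit:

```cpp
#include <cstdint>

// dst and src0 share extents ne[4]; each src1 extent ne1[i] divides ne[i].
// dim 0 varies fastest, matching ggml's memory layout.
static void bin_bcast_ref(const float * src0, const float * src1, float * dst,
                          const int64_t ne[4], const int64_t ne1[4],
                          float (*op)(float, float)) {
    for (int64_t i3 = 0; i3 < ne[3]; i3++)
    for (int64_t i2 = 0; i2 < ne[2]; i2++)
    for (int64_t i1 = 0; i1 < ne[1]; i1++)
    for (int64_t i0 = 0; i0 < ne[0]; i0++) {
        const int64_t d = i0 + ne[0]*(i1 + ne[1]*(i2 + ne[2]*i3));
        const int64_t s = (i0 % ne1[0]) + ne1[0]*((i1 % ne1[1])
                        + ne1[1]*((i2 % ne1[2]) + ne1[2]*(i3 % ne1[3])));
        dst[d] = op(src0[d], src1[s]);
    }
}
```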
```diff
@@ -4849,24 +4892,108 @@ struct bin_bcast_cuda {
 
         GGML_TENSOR_BINARY_OP_LOCALS
 
-        //size_t s0 = nb0 / sizeof(src1_t);
-        size_t s1 = nb1 / sizeof(src1_t);
-        size_t s2 = nb2 / sizeof(src1_t);
-        size_t s3 = nb3 / sizeof(src1_t);
-
-        //size_t s10 = nb10 / sizeof(src1_t);
-        size_t s11 = nb11 / sizeof(src1_t);
-        size_t s12 = nb12 / sizeof(src1_t);
-        size_t s13 = nb13 / sizeof(src1_t);
 
-        const int num_blocks_x = (ne0 + CUDA_ADDMUL_BLOCK_SIZE - 1) / CUDA_ADDMUL_BLOCK_SIZE;
-        dim3 num_blocks(num_blocks_x, ne1, ne2*ne3);
+        int nr0 = ne10/ne0;
+        int nr1 = ne11/ne1;
+        int nr2 = ne12/ne2;
+        int nr3 = ne13/ne3;
+
+        int nr[4] = { nr0, nr1, nr2, nr3 };
+
+        // collapse dimensions until first broadcast dimension
+        int64_t cne0[] = {ne0, ne1, ne2, ne3};
+        int64_t cne1[] = {ne10, ne11, ne12, ne13};
+        size_t cnb0[] = {nb0, nb1, nb2, nb3};
+        size_t cnb1[] = {nb10, nb11, nb12, nb13};
+        auto collapse = [](int64_t cne[]) {
+            cne[0] *= cne[1];
+            cne[1] = cne[2];
+            cne[2] = cne[3];
+            cne[3] = 1;
+        };
+
+        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
+            cnb[1] *= cne[1];
+            cnb[2] *= cne[2];
+            cnb[3] *= cne[3];
+        };
+
+        for (int i = 0; i < 4; i++) {
+            if (nr[i] != 1) {
+                break;
+            }
+            if (i > 0) {
+                collapse_nb(cnb0, cne0);
+                collapse_nb(cnb1, cne1);
+                collapse(cne0);
+                collapse(cne1);
+            }
+        }
+        {
+            int64_t ne0 = cne0[0];
+            int64_t ne1 = cne0[1];
+            int64_t ne2 = cne0[2];
+            int64_t ne3 = cne0[3];
+
+            int64_t ne10 = cne1[0];
+            int64_t ne11 = cne1[1];
+            int64_t ne12 = cne1[2];
+            int64_t ne13 = cne1[3];
+
+            //size_t nb0 = cnb0[0];
+            size_t nb1 = cnb0[1];
+            size_t nb2 = cnb0[2];
+            size_t nb3 = cnb0[3];
+
+            //size_t nb10 = cnb1[0];
+            size_t nb11 = cnb1[1];
+            size_t nb12 = cnb1[2];
+            size_t nb13 = cnb1[3];
+
+            //size_t s0 = nb0 / sizeof(src1_t);
+            size_t s1 = nb1 / sizeof(src1_t);
+            size_t s2 = nb2 / sizeof(src1_t);
+            size_t s3 = nb3 / sizeof(src1_t);
+
+            //size_t s10 = nb10 / sizeof(src1_t);
+            size_t s11 = nb11 / sizeof(src1_t);
+            size_t s12 = nb12 / sizeof(src1_t);
+            size_t s13 = nb13 / sizeof(src1_t);
+
+
+            const int block_size = 128;
+
+            int64_t hne0 = std::max(ne0/2LL, 1LL);
+
+            dim3 block_dims;
+            block_dims.x = std::min<unsigned int>(hne0, block_size);
+            block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
+            block_dims.z = std::min(std::min<unsigned int>(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U);
+
+            dim3 block_nums(
+                (hne0 + block_dims.x - 1) / block_dims.x,
+                (ne1 + block_dims.y - 1) / block_dims.y,
+                (ne2*ne3 + block_dims.z - 1) / block_dims.z
+            );
 
-        k_bin_bcast<bin_op><<<num_blocks, CUDA_ADDMUL_BLOCK_SIZE, 0, stream>>>(src0_dd, src1_dd, dst_dd,
-                ne0,/* ne1, ne2, */ne3,
-                ne10, ne11, ne12, ne13,
-                /* s0, */s1, s2, s3,
-                /* s10,*/ s11, s12, s13);
+            if (block_nums.z > 65535) {
+                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
+                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+                k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
+                    src0_dd, src1_dd, dst_dd,
+                    ne0, ne1, ne2, ne3,
+                    ne10, ne11, ne12, ne13,
+                    /* s0, */ s1, s2, s3,
+                    /* s10, */ s11, s12, s13);
+            } else {
+                k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
+                    src0_dd, src1_dd, dst_dd,
+                    ne0, ne1, ne2, ne3,
+                    ne10, ne11, ne12, ne13,
+                    /* s0, */ s1, s2, s3,
+                    /* s10, */ s11, s12, s13);
+            }
+        }
     }
 };
```
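Two things happen above. First, leading dimensions whose extents match in both operands are collapsed into one contiguous run, so the launch sees fewer, larger dimensions. Second, `hne0 = max(ne0/2, 1)` sizes the x grid for half the elements, so each thread covers two elements along dim 0 via the kernel's grid-stride loop; blocks are capped at 64 in z, and if the z grid would still exceed CUDA's 65535-block limit, the flat 1D `k_bin_bcast_unravel` grid is launched instead. A standalone sketch of just the collapse step, using hypothetical shapes not taken from the commit:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // hypothetical shapes: dst/src0 = {4096, 32, 10, 1}, src1 = {4096, 32, 1, 1}
    int64_t cne0[] = {4096, 32, 10, 1};
    int64_t cne1[] = {4096, 32,  1, 1};

    // broadcast ratios are fixed up front, exactly as in the code above
    int nr[4];
    for (int i = 0; i < 4; i++) {
        nr[i] = (int)(cne1[i]/cne0[i]);
    }

    auto collapse = [](int64_t cne[]) {
        cne[0] *= cne[1];
        cne[1] = cne[2];
        cne[2] = cne[3];
        cne[3] = 1;
    };

    for (int i = 0; i < 4; i++) {
        if (nr[i] != 1) {
            break; // stop at the first broadcast dimension
        }
        if (i > 0) {
            collapse(cne0);
            collapse(cne1);
        }
    }

    // dims 0 and 1 match, so they merge into one run of 131072 elements:
    // dst = {131072, 10, 1, 1}, src1 = {131072, 1, 1, 1}
    printf("dst  = {%lld, %lld, %lld, %lld}\n",
           (long long) cne0[0], (long long) cne0[1],
           (long long) cne0[2], (long long) cne0[3]);
    printf("src1 = {%lld, %lld, %lld, %lld}\n",
           (long long) cne1[0], (long long) cne1[1],
           (long long) cne1[2], (long long) cne1[3]);
}
```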

```diff
@@ -6096,63 +6223,6 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
     }
 }
 
-static void ggml_cuda_op_repeat(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
-    // guaranteed to be an integer due to the check in ggml_can_repeat
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t ne3 = dst->ne[3];
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
-
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const int nr0 = (int)(ne0/ne00);
-    const int nr1 = (int)(ne1/ne01);
-    const int nr2 = (int)(ne2/ne02);
-    const int nr3 = (int)(ne3/ne03);
-
-    // TODO: support for transposed / permuted tensors
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    // TODO: very inefficient, implement in a kernel, or fewer cudaMemcpyAsync calls for contiguous tensors
-    for (int i3 = 0; i3 < nr3; i3++) {
-        for (int k3 = 0; k3 < ne03; k3++) {
-            for (int i2 = 0; i2 < nr2; i2++) {
-                for (int k2 = 0; k2 < ne02; k2++) {
-                    for (int i1 = 0; i1 < nr1; i1++) {
-                        for (int k1 = 0; k1 < ne01; k1++) {
-                            for (int i0 = 0; i0 < nr0; i0++) {
-                                CUDA_CHECK(cudaMemcpyAsync(
-                                    (char *) dst_d + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0,
-                                    (const char *) src0_d + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01,
-                                    ne00*nb0, cudaMemcpyDeviceToDevice, stream));
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    (void) src1;
-    (void) src1_d;
-}
-
 static void ggml_cuda_op_get_rows(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
```
```diff
@@ -6215,7 +6285,16 @@ inline void ggml_cuda_op_bin_bcast(
             ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
         GGML_ASSERT(false);
     }
+}
 
+static void ggml_cuda_op_repeat(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & main_stream) {
+
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
+
+    (void) src1;
+    (void) src1_d;
 }
 
 inline void ggml_cuda_op_add(
```
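The rewritten `ggml_cuda_op_repeat` reuses the broadcast machinery: `dst` supplies the output shape, the tensor to repeat rides in the `src1` slot, and `op_repeat` keeps only the broadcast value, so a single kernel launch replaces the old nest of `cudaMemcpyAsync` calls. A 1D CPU sketch of why that is the same operation (hypothetical helper, not from the commit):

```cpp
#include <vector>

// repeating a length-n source into a length-m destination (n divides m)
// is a broadcast copy: a binary op whose result is the broadcast operand,
// i.e. op_repeat(a, b) == b with the source tensor in the "b" (src1) slot
static std::vector<float> repeat_1d(const std::vector<float> & src, size_t m) {
    std::vector<float> dst(m);
    for (size_t i = 0; i < m; i++) {
        dst[i] = src[i % src.size()];
    }
    return dst;
}
```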
```diff
@@ -8393,7 +8472,8 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
                     break;
                 default:
                     return false;
-            } break;
+            }
+            break;
         case GGML_OP_NORM:
             func = ggml_cuda_norm;
             break;
```
```diff
@@ -8842,10 +8922,10 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
     UNUSED(backend);
 }
 
-static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * tensor) {
-    switch (tensor->op) {
+static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
+    switch (op->op) {
         case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(tensor)) {
+            switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
```
```diff
@@ -8854,7 +8934,23 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
                     return false;
             }
             break;
+        case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
+            {
+                struct ggml_tensor * a;
+                struct ggml_tensor * b;
+                if (op->op == GGML_OP_MUL_MAT) {
+                    a = op->src[0];
+                    b = op->src[1];
+                } else {
+                    a = op->src[2];
+                    b = op->src[1];
+                }
+                if (a->ne[3] != b->ne[3]) {
+                    return false;
+                }
+                return true;
+            } break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
```
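With this change, `supports_op` mirrors what the CUDA mul-mat path can actually execute: the two operands' dim 3 extents must match, and a mismatch is reported as unsupported so the scheduler can fall back instead of hitting an unimplemented case at compute time. A standalone sketch of the predicate, with a hypothetical helper and example shapes:

```cpp
#include <cassert>
#include <cstdint>

// hypothetical helper mirroring the check above: dim 3 must match exactly
static bool mul_mat_dim3_ok(const int64_t a_ne[4], const int64_t b_ne[4]) {
    return a_ne[3] == b_ne[3];
}

int main() {
    const int64_t a1[4] = {128, 64, 8, 2}, b1[4] = {128, 32, 8, 2};
    const int64_t a2[4] = {128, 64, 8, 1}, b2[4] = {128, 32, 8, 2};
    assert( mul_mat_dim3_ok(a1, b1)); // equal dim 3 extents: supported
    assert(!mul_mat_dim3_ok(a2, b2)); // dim 3 mismatch: not supported here
}
```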
```diff
@@ -8868,7 +8964,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
-        case GGML_OP_MUL_MAT:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_CLAMP:
```
```diff
@@ -8913,6 +9008,9 @@ ggml_backend_t ggml_backend_cuda_init(int device) {
         return nullptr;
     }
 
+    // not strictly necessary, but it may reduce the overhead of the first graph_compute
+    ggml_cuda_set_main_device(device);
+
     ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda {
         /* .device = */ device
     };
```
