@@ -434,7 +434,6 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define WARP_SIZE 32
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses

- #define CUDA_ADDMUL_BLOCK_SIZE 256
#define CUDA_GELU_BLOCK_SIZE 256
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_RELU_BLOCK_SIZE 256
@@ -501,6 +500,10 @@ static size_t g_scratch_offset = 0;

static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

+ static __device__ __forceinline__ float op_repeat(const float a, const float b) {
+     return b;
+ }
+
static __device__ __forceinline__ float op_add(const float a, const float b) {
    return a + b;
}
@@ -515,29 +518,69 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {

template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
-         int ne0,/* int ne1, int ne2, */ int ne3,
+         int ne0, int ne1, int ne2, int ne3,
        int ne10, int ne11, int ne12, int ne13,
        /*int s0, */ int s1, int s2, int s3,
        /*int s10,*/ int s11, int s12, int s13) {
-     const int i0 = blockDim.x*blockIdx.x + threadIdx.x;
-     const int i1 = blockIdx.y;
-     const int i2 = blockIdx.z / ne3;
-     const int i3 = blockIdx.z % ne3;
+     const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
+     const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
+     const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
+     const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;

-     if (i0 >= ne0) {
+     if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+         return;
+     }
+
+     const int i11 = i1 % ne11;
+     const int i12 = i2 % ne12;
+     const int i13 = i3 % ne13;
+
+     const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+     const size_t i_dst  = i_src0;
+
+     const src0_t * src0_row = src0 + i_src0;
+     const src1_t * src1_row = src1 + i_src1;
+     dst_t * dst_row = dst + i_dst;
+
+     for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
+         const int i10 = i0 % ne10;
+         dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+     }
+ }
+
+ template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+         int ne0, int ne1, int ne2, int ne3,
+         int ne10, int ne11, int ne12, int ne13,
+         /*int s0, */ int s1, int s2, int s3,
+         /*int s10,*/ int s11, int s12, int s13) {
+
+     const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+     const int i3 = i/(ne2*ne1*ne0);
+     const int i2 = (i/(ne1*ne0)) % ne2;
+     const int i1 = (i/ne0) % ne1;
+     const int i0 = i % ne0;
+
+     if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
        return;
    }

-     const int i10 = i0 % ne10;
    const int i11 = i1 % ne11;
    const int i12 = i2 % ne12;
    const int i13 = i3 % ne13;

-     const size_t i_dst  = i3*s3 + i2*s2 + i1*s1 + i0;
-     const size_t i_src0 = i_dst;
-     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11 + i10;
+     const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+     const size_t i_dst  = i_src0;
+
+     const src0_t * src0_row = src0 + i_src0;
+     const src1_t * src1_row = src1 + i_src1;
+     dst_t * dst_row = dst + i_dst;

-     dst[i_dst] = (dst_t)bin_op((float)src0[i_src0], (float)src1[i_src1]);
+     const int i10 = i0 % ne10;
+     dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
}

static __global__ void gelu_f32(const float * x, float * dst, const int k) {
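
Note (not part of the diff): the new k_bin_bcast kernel implements ggml-style broadcasting by reducing every src1 index modulo its extent, so a src1 that is smaller along some dimension is repeated across src0/dst. Below is a minimal host-side C++ sketch of the same index math, illustrative only, using a hypothetical bin_bcast_ref helper with op_add as the example operator and contiguous f32 tensors assumed.

// Illustrative CPU reference only (not in ggml): modulo-based broadcast indexing as in k_bin_bcast.
#include <cstdint>
#include <cstdio>
#include <vector>

static void bin_bcast_ref(const float * src0, const float * src1, float * dst,
                          int ne0, int ne1, int ne2, int ne3,         // src0/dst extents
                          int ne10, int ne11, int ne12, int ne13) {   // src1 extents
    for (int i3 = 0; i3 < ne3; i3++)
    for (int i2 = 0; i2 < ne2; i2++)
    for (int i1 = 0; i1 < ne1; i1++)
    for (int i0 = 0; i0 < ne0; i0++) {
        const int64_t i_dst  = (((int64_t)i3*ne2 + i2)*ne1 + i1)*ne0 + i0;
        const int64_t i_src1 = (((int64_t)(i3 % ne13)*ne12 + i2 % ne12)*ne11 + i1 % ne11)*ne10 + i0 % ne10;
        dst[i_dst] = src0[i_dst] + src1[i_src1];   // bin_op(a, b) = a + b
    }
}

int main() {
    // src0/dst: 4 x 2, src1: 4 x 1 -> src1 is repeated along the second dimension
    std::vector<float> src0(8, 1.0f), src1 = {10.0f, 20.0f, 30.0f, 40.0f}, dst(8);
    bin_bcast_ref(src0.data(), src1.data(), dst.data(), 4, 2, 1, 1, 4, 1, 1, 1);
    for (float v : dst) printf("%g ", v);          // prints: 11 21 31 41 11 21 31 41
    printf("\n");
}

In the kernel itself, i0 is additionally strided by blockDim.x*gridDim.x so one thread can cover several elements of the innermost dimension, and k_bin_bcast_unravel flattens all four dimensions into a single 1D index for grids that would exceed the 3D launch limits.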
@@ -4849,24 +4892,108 @@ struct bin_bcast_cuda {

        GGML_TENSOR_BINARY_OP_LOCALS

-         // size_t s0 = nb0 / sizeof(src1_t);
-         size_t s1 = nb1 / sizeof(src1_t);
-         size_t s2 = nb2 / sizeof(src1_t);
-         size_t s3 = nb3 / sizeof(src1_t);
-
-         // size_t s10 = nb10 / sizeof(src1_t);
-         size_t s11 = nb11 / sizeof(src1_t);
-         size_t s12 = nb12 / sizeof(src1_t);
-         size_t s13 = nb13 / sizeof(src1_t);

-         const int num_blocks_x = (ne0 + CUDA_ADDMUL_BLOCK_SIZE - 1) / CUDA_ADDMUL_BLOCK_SIZE;
-         dim3 num_blocks(num_blocks_x, ne1, ne2*ne3);
+         int nr0 = ne10/ne0;
+         int nr1 = ne11/ne1;
+         int nr2 = ne12/ne2;
+         int nr3 = ne13/ne3;
+
+         int nr[4] = { nr0, nr1, nr2, nr3 };
+
+         // collapse dimensions until first broadcast dimension
+         int64_t cne0[] = {ne0, ne1, ne2, ne3};
+         int64_t cne1[] = {ne10, ne11, ne12, ne13};
+         size_t cnb0[] = {nb0, nb1, nb2, nb3};
+         size_t cnb1[] = {nb10, nb11, nb12, nb13};
+         auto collapse = [](int64_t cne[]) {
+             cne[0] *= cne[1];
+             cne[1] = cne[2];
+             cne[2] = cne[3];
+             cne[3] = 1;
+         };
+
+         auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
+             cnb[1] *= cne[1];
+             cnb[2] *= cne[2];
+             cnb[3] *= cne[3];
+         };
+
+         for (int i = 0; i < 4; i++) {
+             if (nr[i] != 1) {
+                 break;
+             }
+             if (i > 0) {
+                 collapse_nb(cnb0, cne0);
+                 collapse_nb(cnb1, cne1);
+                 collapse(cne0);
+                 collapse(cne1);
+             }
+         }
+         {
+             int64_t ne0 = cne0[0];
+             int64_t ne1 = cne0[1];
+             int64_t ne2 = cne0[2];
+             int64_t ne3 = cne0[3];
+
+             int64_t ne10 = cne1[0];
+             int64_t ne11 = cne1[1];
+             int64_t ne12 = cne1[2];
+             int64_t ne13 = cne1[3];
+
+             // size_t nb0 = cnb0[0];
+             size_t nb1 = cnb0[1];
+             size_t nb2 = cnb0[2];
+             size_t nb3 = cnb0[3];
+
+             // size_t nb10 = cnb1[0];
+             size_t nb11 = cnb1[1];
+             size_t nb12 = cnb1[2];
+             size_t nb13 = cnb1[3];
+
+             // size_t s0 = nb0 / sizeof(src1_t);
+             size_t s1 = nb1 / sizeof(src1_t);
+             size_t s2 = nb2 / sizeof(src1_t);
+             size_t s3 = nb3 / sizeof(src1_t);
+
+             // size_t s10 = nb10 / sizeof(src1_t);
+             size_t s11 = nb11 / sizeof(src1_t);
+             size_t s12 = nb12 / sizeof(src1_t);
+             size_t s13 = nb13 / sizeof(src1_t);
+
+
+             const int block_size = 128;
+
+             int64_t hne0 = std::max(ne0/2LL, 1LL);
+
+             dim3 block_dims;
+             block_dims.x = std::min<unsigned int>(hne0, block_size);
+             block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
+             block_dims.z = std::min(std::min<unsigned int>(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U);
+
+             dim3 block_nums(
+                 (hne0 + block_dims.x - 1) / block_dims.x,
+                 (ne1 + block_dims.y - 1) / block_dims.y,
+                 (ne2*ne3 + block_dims.z - 1) / block_dims.z
+             );

-         k_bin_bcast<bin_op><<<num_blocks, CUDA_ADDMUL_BLOCK_SIZE, 0, stream>>>(src0_dd, src1_dd, dst_dd,
-                 ne0,/* ne1, ne2, */ ne3,
-                 ne10, ne11, ne12, ne13,
-                 /* s0, */ s1, s2, s3,
-                 /* s10,*/ s11, s12, s13);
+             if (block_nums.z > 65535) {
+                 // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
+                 int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+                 k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
+                     src0_dd, src1_dd, dst_dd,
+                     ne0, ne1, ne2, ne3,
+                     ne10, ne11, ne12, ne13,
+                     /* s0, */ s1, s2, s3,
+                     /* s10, */ s11, s12, s13);
+             } else {
+                 k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
+                     src0_dd, src1_dd, dst_dd,
+                     ne0, ne1, ne2, ne3,
+                     ne10, ne11, ne12, ne13,
+                     /* s0, */ s1, s2, s3,
+                     /* s10, */ s11, s12, s13);
+             }
+         }
    }
};

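Note (not part of the diff): before launching, bin_bcast_cuda folds leading dimensions whose repeat factor is 1 into a single larger dimension, so the kernel iterates over fewer, longer rows. A small standalone sketch of that collapse on example shapes follows; it is illustrative only, shows extents but not the strides (which collapse_nb folds the same way), and computes the repeat factor as dst extent over src1 extent.

// Illustrative only: fold leading non-broadcast dimensions, mirroring collapse()/collapse_nb() above.
#include <cstdint>
#include <cstdio>
#include <initializer_list>

int main() {
    // dst/src0 extents: 32 x 8 x 4 x 2, src1 extents: 32 x 8 x 1 x 2 (broadcast in dim 2)
    int64_t cne0[4] = {32, 8, 4, 2};
    int64_t cne1[4] = {32, 8, 1, 2};

    // repeat factor per dimension, computed from the original extents
    int64_t nr[4];
    for (int i = 0; i < 4; i++) {
        nr[i] = cne0[i]/cne1[i];
    }

    for (int i = 0; i < 4; i++) {
        if (nr[i] != 1) {
            break;                 // stop at the first broadcast dimension
        }
        if (i > 0) {
            for (int64_t * cne : {cne0, cne1}) {
                cne[0] *= cne[1];  // fold the two leading dimensions into one
                cne[1]  = cne[2];
                cne[2]  = cne[3];
                cne[3]  = 1;
            }
        }
    }

    // expected: cne0 = 256 4 2 1, cne1 = 256 1 2 1
    printf("cne0 = %lld %lld %lld %lld\n", (long long)cne0[0], (long long)cne0[1], (long long)cne0[2], (long long)cne0[3]);
    printf("cne1 = %lld %lld %lld %lld\n", (long long)cne1[0], (long long)cne1[1], (long long)cne1[2], (long long)cne1[3]);
}

The block_nums.z > 65535 check above falls back to the 1D k_bin_bcast_unravel kernel because CUDA grids are limited to 65535 blocks in the y and z dimensions.
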
@@ -6096,63 +6223,6 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
    }
}

- static void ggml_cuda_op_repeat(
-     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-     const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
-     // guaranteed to be an integer due to the check in ggml_can_repeat
-     const int64_t ne0 = dst->ne[0];
-     const int64_t ne1 = dst->ne[1];
-     const int64_t ne2 = dst->ne[2];
-     const int64_t ne3 = dst->ne[3];
-
-     const int64_t ne00 = src0->ne[0];
-     const int64_t ne01 = src0->ne[1];
-     const int64_t ne02 = src0->ne[2];
-     const int64_t ne03 = src0->ne[3];
-
-     const size_t nb0 = dst->nb[0];
-     const size_t nb1 = dst->nb[1];
-     const size_t nb2 = dst->nb[2];
-     const size_t nb3 = dst->nb[3];
-
-     const size_t nb00 = src0->nb[0];
-     const size_t nb01 = src0->nb[1];
-     const size_t nb02 = src0->nb[2];
-     const size_t nb03 = src0->nb[3];
-
-     const int nr0 = (int)(ne0/ne00);
-     const int nr1 = (int)(ne1/ne01);
-     const int nr2 = (int)(ne2/ne02);
-     const int nr3 = (int)(ne3/ne03);
-
-     // TODO: support for transposed / permuted tensors
-     GGML_ASSERT(nb0  == sizeof(float));
-     GGML_ASSERT(nb00 == sizeof(float));
-
-     // TODO: very inefficient, implement in a kernel, or fewer cudaMemcpyAsync calls for contiguous tensors
-     for (int i3 = 0; i3 < nr3; i3++) {
-         for (int k3 = 0; k3 < ne03; k3++) {
-             for (int i2 = 0; i2 < nr2; i2++) {
-                 for (int k2 = 0; k2 < ne02; k2++) {
-                     for (int i1 = 0; i1 < nr1; i1++) {
-                         for (int k1 = 0; k1 < ne01; k1++) {
-                             for (int i0 = 0; i0 < nr0; i0++) {
-                                 CUDA_CHECK(cudaMemcpyAsync(
-                                     (char *) dst_d + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0,
-                                     (const char *) src0_d + (k3)*nb03 + (k2)*nb02 + (k1)*nb01,
-                                     ne00*nb0, cudaMemcpyDeviceToDevice, stream));
-                             }
-                         }
-                     }
-                 }
-             }
-         }
-     }
-
-     (void) src1;
-     (void) src1_d;
- }
-
static void ggml_cuda_op_get_rows(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
    const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
@@ -6215,7 +6285,16 @@ inline void ggml_cuda_op_bin_bcast(
            ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
        GGML_ASSERT(false);
    }
+ }

+ static void ggml_cuda_op_repeat(
+     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+     const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & main_stream) {
+
+     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
+
+     (void) src1;
+     (void) src1_d;
}

inline void ggml_cuda_op_add(
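
Note (not part of the diff): with op_repeat returning its second argument, GGML_OP_REPEAT becomes a special case of the broadcast path — ggml_cuda_op_repeat passes dst as the shape-defining operand (with a null data pointer) and src0 as the broadcast operand, so each output element reads the corresponding src0 element modulo its extents. A rough CPU equivalent for contiguous f32 tensors, illustrative only with a hypothetical repeat_ref helper:

#include <cstdint>

// dst has the repeated (larger) shape; src0 is the tensor being repeated.
void repeat_ref(const float * src0, float * dst,
                int ne0,  int ne1,  int ne2,  int ne3,     // dst extents
                int ne00, int ne01, int ne02, int ne03) {  // src0 extents
    for (int i3 = 0; i3 < ne3; i3++)
    for (int i2 = 0; i2 < ne2; i2++)
    for (int i1 = 0; i1 < ne1; i1++)
    for (int i0 = 0; i0 < ne0; i0++) {
        const int64_t i_dst  = (((int64_t)i3*ne2 + i2)*ne1 + i1)*ne0 + i0;
        const int64_t i_src0 = (((int64_t)(i3 % ne03)*ne02 + i2 % ne02)*ne01 + i1 % ne01)*ne00 + i0 % ne00;
        dst[i_dst] = src0[i_src0];   // op_repeat(a, b) == b: values come from src0 only
    }
}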
@@ -8393,7 +8472,8 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
                    break;
                default:
                    return false;
-             } break;
+             }
+             break;
        case GGML_OP_NORM:
            func = ggml_cuda_norm;
            break;
@@ -8842,10 +8922,10 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
    UNUSED(backend);
}

- static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * tensor) {
-     switch (tensor->op) {
+ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
+     switch (op->op) {
        case GGML_OP_UNARY:
-             switch (ggml_get_unary_op(tensor)) {
+             switch (ggml_get_unary_op(op)) {
                case GGML_UNARY_OP_GELU:
                case GGML_UNARY_OP_SILU:
                case GGML_UNARY_OP_RELU:
@@ -8854,7 +8934,23 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
                    return false;
            }
            break;
+         case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID:
+             {
+                 struct ggml_tensor * a;
+                 struct ggml_tensor * b;
+                 if (op->op == GGML_OP_MUL_MAT) {
+                     a = op->src[0];
+                     b = op->src[1];
+                 } else {
+                     a = op->src[2];
+                     b = op->src[1];
+                 }
+                 if (a->ne[3] != b->ne[3]) {
+                     return false;
+                 }
+                 return true;
+             } break;
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
@@ -8868,7 +8964,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_RMS_NORM:
-         case GGML_OP_MUL_MAT:
        case GGML_OP_SCALE:
        case GGML_OP_SQR:
        case GGML_OP_CLAMP:
@@ -8913,6 +9008,9 @@ ggml_backend_t ggml_backend_cuda_init(int device) {
        return nullptr;
    }

+     // not strictly necessary, but it may reduce the overhead of the first graph_compute
+     ggml_cuda_set_main_device(device);
+
    ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda {
        /* .device = */ device
    };