@@ -3,6 +3,7 @@
#include "dequantize.hpp"
#include "presets.hpp"

+
static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
    const sycl::half *x = (const sycl::half *)vx;

@@ -76,7 +77,7 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *

    // sum up partial sums and write back result
#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
+    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }
@@ -104,7 +105,7 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,

        stream->parallel_for(
            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
                                                          nrows, item_ct1);
            });
@@ -227,7 +228,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,

    // sum up partial sums and write back result
#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }
@@ -346,7 +347,7 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,

    // sum up partial sums and write back result
#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }
@@ -499,7 +500,7 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,

    // sum up partial sums and write back result
#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }
@@ -633,7 +634,7 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,

    // sum up partial sums and write back result
#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }
@@ -748,7 +749,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa

    // sum up partial sums and write back result
#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }
@@ -774,7 +775,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,

        stream->parallel_for(
            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
                    vx, y, dst, ncols, nrows, item_ct1);
            });
@@ -795,7 +796,7 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,

        stream->parallel_for(
            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
                    vx, y, dst, ncols, nrows, item_ct1);
            });
@@ -816,7 +817,7 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,

        stream->parallel_for(
            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
                    vx, y, dst, ncols, nrows, item_ct1);
            });
@@ -837,7 +838,7 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,

        stream->parallel_for(
            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
                    vx, y, dst, ncols, nrows, item_ct1);
            });
@@ -858,7 +859,7 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,

        stream->parallel_for(
            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
                    vx, y, dst, ncols, nrows, item_ct1);
            });
@@ -873,10 +874,10 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
    const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
    const int block_num_y = (nrows + ny - 1) / ny;
    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, 32);
+    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
    stream->parallel_for(
        sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
            dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
        });
}
@@ -889,10 +890,10 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
    const int ny = 2 / K_QUANTS_PER_ITERATION;
    const int block_num_y = (nrows + ny - 1) / ny;
    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, 32);
+    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
    stream->parallel_for(
        sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
            dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
        });
}
@@ -905,10 +906,10 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
    const int ny = 2 / K_QUANTS_PER_ITERATION;
    const int block_num_y = (nrows + ny - 1) / ny;
    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, 32);
+    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
    stream->parallel_for(
        sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
            dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
        });
}
@@ -918,10 +919,10 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
                                             const int nrows,
                                             dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_K == 0);
-    const sycl::range<3> block_dims(1, 1, 32);
+    const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
    stream->parallel_for(
        sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
            dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
        });
}
@@ -934,10 +935,10 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
    const int ny = 2 / K_QUANTS_PER_ITERATION;
    const int block_num_y = (nrows + ny - 1) / ny;
    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, 32);
+    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
    stream->parallel_for(
        sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
            dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
        });
}
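
Note on the pattern being parameterized: each of the loops above is a sub-group butterfly reduction. Every lane repeatedly adds the partial sum held by the lane whose index differs in one bit, so after log2(sub-group width) steps all lanes hold the full dot product. The starting mask therefore has to be half the sub-group width (WARP_SIZE / 2 or QK_WARP_SIZE / 2); a hardcoded 16 is only correct when the sub-group is 32 lanes wide. A minimal standalone sketch of the same idea, not code from this patch: it assumes a fixed 32-wide sub-group and uses the SYCL 2020 free function sycl::permute_group_by_xor, which performs the same lane exchange as the dpct helper in the diff.

#include <sycl/sycl.hpp>

constexpr int WARP_SIZE = 32;  // assumed sub-group width for this sketch

// Butterfly (XOR-shuffle) reduction over one sub-group. Callable only from
// inside a kernel, e.g. warp_reduce_sum(item_ct1.get_sub_group(), tmp).
static inline float warp_reduce_sum(sycl::sub_group sg, float tmp) {
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
        // add the partial sum of the lane whose id differs in bit 'mask'
        tmp += sycl::permute_group_by_xor(sg, tmp, mask);
    }
    return tmp;
}

If the device's sub-group were narrower (say 16 lanes), starting the mask at a literal 16 would address lanes outside the sub-group, whereas WARP_SIZE / 2 keeps the reduction within bounds.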