Skip to content

Commit e05a9a1

Browse files
ggerganoviThalay
authored andcommitted
ggml : sync sycl (skip) (#0)
1 parent 0571114 commit e05a9a1

File tree

15 files changed

+1156
-129
lines changed

15 files changed

+1156
-129
lines changed

ggml/src/ggml-sycl/backend.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,8 @@
1919
#include "dmmv.hpp"
2020
#include "mmq.hpp"
2121
#include "mmvq.hpp"
22+
#include "rope.hpp"
23+
#include "norm.hpp"
24+
#include "softmax.hpp"
2225

2326
#endif // GGML_SYCL_BACKEND_HPP

ggml/src/ggml-sycl/common.hpp

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <iostream>
1818

1919
#include "dpct/helper.hpp"
20+
#include "ggml-sycl.h"
2021
#include "presets.hpp"
2122

2223
#define GGML_COMMON_DECL_SYCL
@@ -46,10 +47,6 @@ static int g_ggml_sycl_debug = 0;
4647
} \
4748
}()
4849

49-
// #define DEBUG_SYCL_MALLOC
50-
51-
static int g_work_group_size = 0;
52-
// typedef sycl::half ggml_fp16_t;
5350

5451
#define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
5552
#define VER_4VEC 610 // todo for hardward optimize.
@@ -192,6 +189,8 @@ struct ggml_sycl_device_info {
192189
sycl_device_info devices[GGML_SYCL_MAX_DEVICES] = {};
193190

194191
std::array<float, GGML_SYCL_MAX_DEVICES> default_tensor_split = {};
192+
193+
int max_work_group_sizes[GGML_SYCL_MAX_DEVICES] = {0};
195194
};
196195

197196
const ggml_sycl_device_info & ggml_sycl_info();
@@ -294,5 +293,57 @@ struct ggml_backend_sycl_context {
294293
}
295294
};
296295

296+
// common device functions
297+
298+
static __dpct_inline__ float warp_reduce_sum(float x,
299+
const sycl::nd_item<3>& item_ct1) {
300+
#pragma unroll
301+
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
302+
/*
303+
DPCT1096:98: The right-most dimension of the work-group used in the SYCL
304+
kernel that calls this function may be less than "32". The function
305+
"dpct::permute_sub_group_by_xor" may return an unexpected result on the
306+
CPU device. Modify the size of the work-group to ensure that the value
307+
of the right-most dimension is a multiple of "32".
308+
*/
309+
x += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), x, mask);
310+
}
311+
return x;
312+
}
313+
314+
static __dpct_inline__ sycl::float2
315+
warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) {
316+
#pragma unroll
317+
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
318+
a.x() += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), a.x(),
319+
mask);
320+
a.y() += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), a.y(),
321+
mask);
322+
}
323+
return a;
324+
}
325+
326+
static __dpct_inline__ float warp_reduce_max(float x,
327+
const sycl::nd_item<3>& item_ct1) {
328+
#pragma unroll
329+
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
330+
/*
331+
DPCT1096:97: The right-most dimension of the work-group used in the SYCL
332+
kernel that calls this function may be less than "32". The function
333+
"dpct::permute_sub_group_by_xor" may return an unexpected result on the
334+
CPU device. Modify the size of the work-group to ensure that the value
335+
of the right-most dimension is a multiple of "32".
336+
*/
337+
x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
338+
item_ct1.get_sub_group(), x, mask));
339+
}
340+
return x;
341+
}
342+
343+
// Helper for vec loading aligned data
344+
template <typename Tp, int n>
345+
inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) {
346+
return *reinterpret_cast<const sycl::vec<Tp, n>*>(aligned_ptr);
347+
}
297348

298349
#endif // GGML_SYCL_COMMON_HPP

ggml/src/ggml-sycl/convert.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,15 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
152152
dpct::has_capability_or_fail(stream->get_device(),
153153
{sycl::aspect::fp16});
154154

155-
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
155+
stream->submit([&](sycl::handler &cgh) {
156+
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
157+
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
156158
sycl::range<3>(1, 1, 32),
157159
sycl::range<3>(1, 1, 32)),
158160
[=](sycl::nd_item<3> item_ct1) {
159-
dequantize_block_q4_K(vx, y, item_ct1);
161+
dequantize_block_q4_K(vx, y, scale_local_acc.get_pointer(), item_ct1);
160162
});
163+
});
161164
}
162165
}
163166

ggml/src/ggml-sycl/dequantize.hpp

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,8 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
293293
#if QK_K == 256
294294
static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
295295
if (j < 4) {
296-
d = q[j] & 63; m = q[j + 4] & 63;
296+
d = q[j] & 63;
297+
m = q[j + 4] & 63;
297298
} else {
298299
d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
299300
m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
@@ -303,7 +304,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
303304

304305
template<typename dst_t>
305306
static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
306-
const sycl::nd_item<3> &item_ct1) {
307+
uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
307308
const block_q4_K * x = (const block_q4_K *) vx;
308309

309310
const int i = item_ct1.get_group(2);
@@ -318,19 +319,26 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
318319

319320
dst_t * y = yy + i*QK_K + 64*il + n*ir;
320321

321-
const float dall = x[i].dm[0];
322-
const float dmin = x[i].dm[1];
322+
const sycl::half2 dm = x[i].dm;
323+
const float dall = dm[0];
324+
const float dmin = dm[1];
323325

324-
const uint8_t * q = x[i].qs + 32*il + n*ir;
326+
if (tid < 12)
327+
scales_local[tid] = x[i].scales[tid];
328+
item_ct1.barrier(sycl::access::fence_space::local_space);
325329

326330
uint8_t sc, m;
327-
get_scale_min_k4(is + 0, x[i].scales, sc, m);
328-
const float d1 = dall * sc; const float m1 = dmin * m;
329-
get_scale_min_k4(is + 1, x[i].scales, sc, m);
330-
const float d2 = dall * sc; const float m2 = dmin * m;
331+
get_scale_min_k4(is + 0, scales_local, sc, m);
332+
const float d1 = dall * sc;
333+
const float m1 = dmin * m;
334+
get_scale_min_k4(is + 1, scales_local, sc, m);
335+
const float d2 = dall * sc;
336+
const float m2 = dmin * m;
337+
338+
sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir);
331339
for (int l = 0; l < n; ++l) {
332-
y[l + 0] = d1 * (q[l] & 0xF) - m1;
333-
y[l +32] = d2 * (q[l] >> 4) - m2;
340+
y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
341+
y[l +32] = d2 * (q_vec[l] >> 4) - m2;
334342
}
335343
#else
336344
const int tid = item_ct1.get_local_id(2);

ggml/src/ggml-sycl/dmmv.cpp

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "dequantize.hpp"
44
#include "presets.hpp"
55

6+
67
static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
78
const sycl::half *x = (const sycl::half *)vx;
89

@@ -76,7 +77,7 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
7677

7778
// sum up partial sums and write back result
7879
#pragma unroll
79-
for (int mask = 16; mask > 0; mask >>= 1) {
80+
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
8081
tmp +=
8182
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
8283
}
@@ -104,7 +105,7 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
104105

105106
stream->parallel_for(
106107
sycl::nd_range<3>(block_nums * block_dims, block_dims),
107-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
108+
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
108109
dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
109110
nrows, item_ct1);
110111
});
@@ -227,7 +228,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
227228

228229
// sum up partial sums and write back result
229230
#pragma unroll
230-
for (int mask = 16; mask > 0; mask >>= 1) {
231+
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
231232
tmp +=
232233
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
233234
}
@@ -346,7 +347,7 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
346347

347348
// sum up partial sums and write back result
348349
#pragma unroll
349-
for (int mask = 16; mask > 0; mask >>= 1) {
350+
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
350351
tmp +=
351352
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
352353
}
@@ -499,7 +500,7 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
499500

500501
// sum up partial sums and write back result
501502
#pragma unroll
502-
for (int mask = 16; mask > 0; mask >>= 1) {
503+
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
503504
tmp +=
504505
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
505506
}
@@ -633,7 +634,7 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
633634

634635
// sum up partial sums and write back result
635636
#pragma unroll
636-
for (int mask = 16; mask > 0; mask >>= 1) {
637+
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
637638
tmp +=
638639
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
639640
}
@@ -748,7 +749,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
748749

749750
// sum up partial sums and write back result
750751
#pragma unroll
751-
for (int mask = 16; mask > 0; mask >>= 1) {
752+
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
752753
tmp +=
753754
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
754755
}
@@ -774,7 +775,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
774775

775776
stream->parallel_for(
776777
sycl::nd_range<3>(block_nums * block_dims, block_dims),
777-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
778+
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
778779
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
779780
vx, y, dst, ncols, nrows, item_ct1);
780781
});
@@ -795,7 +796,7 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
795796

796797
stream->parallel_for(
797798
sycl::nd_range<3>(block_nums * block_dims, block_dims),
798-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
799+
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
799800
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
800801
vx, y, dst, ncols, nrows, item_ct1);
801802
});
@@ -816,7 +817,7 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
816817

817818
stream->parallel_for(
818819
sycl::nd_range<3>(block_nums * block_dims, block_dims),
819-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
820+
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
820821
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
821822
vx, y, dst, ncols, nrows, item_ct1);
822823
});
@@ -837,7 +838,7 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
837838

838839
stream->parallel_for(
839840
sycl::nd_range<3>(block_nums * block_dims, block_dims),
840-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
841+
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
841842
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
842843
vx, y, dst, ncols, nrows, item_ct1);
843844
});
@@ -858,7 +859,7 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
858859

859860
stream->parallel_for(
860861
sycl::nd_range<3>(block_nums * block_dims, block_dims),
861-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
862+
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
862863
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
863864
vx, y, dst, ncols, nrows, item_ct1);
864865
});
@@ -873,10 +874,10 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
873874
const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
874875
const int block_num_y = (nrows + ny - 1) / ny;
875876
const sycl::range<3> block_nums(1, 1, block_num_y);
876-
const sycl::range<3> block_dims(1, ny, 32);
877+
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
877878
stream->parallel_for(
878879
sycl::nd_range<3>(block_nums * block_dims, block_dims),
879-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
880+
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
880881
dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
881882
});
882883
}
@@ -889,10 +890,10 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
889890
const int ny = 2 / K_QUANTS_PER_ITERATION;
890891
const int block_num_y = (nrows + ny - 1) / ny;
891892
const sycl::range<3> block_nums(1, 1, block_num_y);
892-
const sycl::range<3> block_dims(1, ny, 32);
893+
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
893894
stream->parallel_for(
894895
sycl::nd_range<3>(block_nums * block_dims, block_dims),
895-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
896+
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
896897
dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
897898
});
898899
}
@@ -905,10 +906,10 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
905906
const int ny = 2 / K_QUANTS_PER_ITERATION;
906907
const int block_num_y = (nrows + ny - 1) / ny;
907908
const sycl::range<3> block_nums(1, 1, block_num_y);
908-
const sycl::range<3> block_dims(1, ny, 32);
909+
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
909910
stream->parallel_for(
910911
sycl::nd_range<3>(block_nums * block_dims, block_dims),
911-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
912+
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
912913
dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
913914
});
914915
}
@@ -918,10 +919,10 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
918919
const int nrows,
919920
dpct::queue_ptr stream) {
920921
GGML_ASSERT(ncols % QK_K == 0);
921-
const sycl::range<3> block_dims(1, 1, 32);
922+
const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
922923
stream->parallel_for(
923924
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
924-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
925+
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
925926
dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
926927
});
927928
}
@@ -934,10 +935,10 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
934935
const int ny = 2 / K_QUANTS_PER_ITERATION;
935936
const int block_num_y = (nrows + ny - 1) / ny;
936937
const sycl::range<3> block_nums(1, 1, block_num_y);
937-
const sycl::range<3> block_dims(1, ny, 32);
938+
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
938939
stream->parallel_for(
939940
sycl::nd_range<3>(block_nums * block_dims, block_dims),
940-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
941+
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
941942
dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
942943
});
943944
}

0 commit comments

Comments
 (0)