Skip to content

Commit 72fabb4

Browse files
committed
Convert all simple parallel_for to nd_launch from enqueue_functions
extension Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
1 parent df2b477 commit 72fabb4

File tree

11 files changed

+92
-92
lines changed

11 files changed

+92
-92
lines changed

ggml/src/ggml-sycl/binbcast.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ struct bin_bcast_sycl {
225225
dpct::has_capability_or_fail(stream->get_device(),
226226
{sycl::aspect::fp16});
227227

228-
stream->parallel_for(
228+
syclex::nd_launch(*stream,
229229
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
230230
sycl::range<3>(1, 1, block_size),
231231
sycl::range<3>(1, 1, block_size)),
@@ -246,7 +246,7 @@ struct bin_bcast_sycl {
246246
dpct::has_capability_or_fail(stream->get_device(),
247247
{sycl::aspect::fp16});
248248

249-
stream->parallel_for(
249+
syclex::nd_launch(*stream,
250250
sycl::nd_range<3>(block_nums * block_dims, block_dims),
251251
[=](sycl::nd_item<3> item_ct1) {
252252
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,

ggml/src/ggml-sycl/concat.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
8989
sycl::range<3> gridDim(ne2, ne1, num_blocks);
9090
switch (dim) {
9191
case 0:
92-
stream->parallel_for(
92+
syclex::nd_launch(*stream,
9393
sycl::nd_range<3>(gridDim *
9494
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
9595
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
@@ -98,7 +98,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
9898
});
9999
break;
100100
case 1:
101-
stream->parallel_for(
101+
syclex::nd_launch(*stream,
102102
sycl::nd_range<3>(gridDim *
103103
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
104104
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
@@ -108,7 +108,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
108108
break;
109109
// dim >=2 will be dispatched to the default path
110110
default:
111-
stream->parallel_for(
111+
syclex::nd_launch(*stream,
112112
sycl::nd_range<3>(gridDim *
113113
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
114114
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
@@ -129,7 +129,7 @@ static void concat_f32_sycl_non_cont(
129129
int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
130130
uint64_t nb3, int32_t dim) {
131131
sycl::range<3> gridDim(ne3, ne2, ne1);
132-
stream->parallel_for(
132+
syclex::nd_launch(*stream,
133133
sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)),
134134
[=](sycl::nd_item<3> item_ct1) {
135135
int64_t i3 = item_ct1.get_group(0);

ggml/src/ggml-sycl/conv.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ static void conv_transpose_1d_f32_f32_sycl(
5959
const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
6060
const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
6161
const sycl::range<3> block_nums(1, 1, num_blocks);
62-
stream->parallel_for(
62+
syclex::nd_launch(*stream,
6363
sycl::nd_range<3>(
6464
block_nums * block_dims, block_dims),
6565
[=](sycl::nd_item<3> item_ct1) {

ggml/src/ggml-sycl/convert.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ static void dequantize_block_sycl(const void *__restrict__ vx,
3333
{
3434
dpct::has_capability_or_fail(stream->get_device(),
3535
{sycl::aspect::fp16});
36-
stream->parallel_for(
36+
syclex::nd_launch(*stream,
3737
sycl::nd_range<3>(
3838
sycl::range<3>(1, 1, num_blocks) *
3939
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
@@ -53,7 +53,7 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
5353
dpct::has_capability_or_fail(stream->get_device(),
5454
{sycl::aspect::fp16});
5555

56-
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
56+
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
5757
sycl::range<3>(1, 1, 64),
5858
sycl::range<3>(1, 1, 64)),
5959
[=](sycl::nd_item<3> item_ct1) {
@@ -65,7 +65,7 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
6565
dpct::has_capability_or_fail(stream->get_device(),
6666
{sycl::aspect::fp16});
6767

68-
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
68+
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
6969
sycl::range<3>(1, 1, 32),
7070
sycl::range<3>(1, 1, 32)),
7171
[=](sycl::nd_item<3> item_ct1) {
@@ -85,7 +85,7 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
8585
dpct::has_capability_or_fail(stream->get_device(),
8686
{sycl::aspect::fp16});
8787

88-
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
88+
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
8989
sycl::range<3>(1, 1, 64),
9090
sycl::range<3>(1, 1, 64)),
9191
[=](sycl::nd_item<3> item_ct1) {
@@ -97,7 +97,7 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
9797
dpct::has_capability_or_fail(stream->get_device(),
9898
{sycl::aspect::fp16});
9999

100-
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
100+
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
101101
sycl::range<3>(1, 1, 32),
102102
sycl::range<3>(1, 1, 32)),
103103
[=](sycl::nd_item<3> item_ct1) {
@@ -116,7 +116,7 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
116116
dpct::has_capability_or_fail(stream->get_device(),
117117
{sycl::aspect::fp16});
118118

119-
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
119+
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
120120
sycl::range<3>(1, 1, 32),
121121
sycl::range<3>(1, 1, 32)),
122122
[=](sycl::nd_item<3> item_ct1) {
@@ -135,7 +135,7 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int
135135
int constexpr WARP_K = WARP_SIZE * QK4_0;
136136
const int n_warp = (k + WARP_K - 1) / WARP_K;
137137
GGML_ASSERT(k % 2 == 0);
138-
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
138+
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
139139
sycl::range<3>(1, 1, WARP_SIZE),
140140
sycl::range<3>(1, 1, WARP_SIZE)),
141141
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
@@ -153,7 +153,7 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
153153
dpct::has_capability_or_fail(stream->get_device(),
154154
{sycl::aspect::fp16});
155155

156-
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
156+
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
157157
sycl::range<3>(1, 1, 32),
158158
sycl::range<3>(1, 1, 32)),
159159
[=](sycl::nd_item<3> item_ct1) {
@@ -210,7 +210,7 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
210210
dpct::has_capability_or_fail(stream->get_device(),
211211
{sycl::aspect::fp16});
212212

213-
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
213+
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
214214
sycl::range<3>(1, 1, 64),
215215
sycl::range<3>(1, 1, 64)),
216216
[=](sycl::nd_item<3> item_ct1) {
@@ -222,7 +222,7 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
222222
dpct::has_capability_or_fail(stream->get_device(),
223223
{sycl::aspect::fp16});
224224

225-
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
225+
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
226226
sycl::range<3>(1, 1, 32),
227227
sycl::range<3>(1, 1, 32)),
228228
[=](sycl::nd_item<3> item_ct1) {
@@ -242,7 +242,7 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
242242
dpct::has_capability_or_fail(stream->get_device(),
243243
{sycl::aspect::fp16});
244244

245-
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
245+
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
246246
sycl::range<3>(1, 1, 64),
247247
sycl::range<3>(1, 1, 64)),
248248
[=](sycl::nd_item<3> item_ct1) {
@@ -254,7 +254,7 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
254254
dpct::has_capability_or_fail(stream->get_device(),
255255
{sycl::aspect::fp16});
256256

257-
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
257+
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
258258
sycl::range<3>(1, 1, 32),
259259
sycl::range<3>(1, 1, 32)),
260260
[=](sycl::nd_item<3> item_ct1) {
@@ -271,7 +271,7 @@ static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const i
271271

272272
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
273273

274-
stream->parallel_for(
274+
syclex::nd_launch(*stream,
275275
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
276276
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
277277
}

0 commit comments

Comments
 (0)