@@ -89,33 +89,24 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
89
89
sycl::range<3 > gridDim (ne2, ne1, num_blocks);
90
90
switch (dim) {
91
91
case 0 :
92
- sycl_parallel_for (stream,
93
- sycl::nd_range<3 >(gridDim *
94
- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
95
- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
96
- [=](sycl::nd_item<3 > item_ct1) {
97
- concat_f32_dim0 (x, y, dst, ne0, ne00, item_ct1);
98
- });
99
- break ;
92
+ sycl_parallel_for (stream,
93
+ sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
94
+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
95
+ [=](sycl::nd_item<3 > item_ct1) { concat_f32_dim0 (x, y, dst, ne0, ne00, item_ct1); });
96
+ break ;
100
97
case 1 :
101
- sycl_parallel_for (stream,
102
- sycl::nd_range<3 >(gridDim *
103
- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
104
- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
105
- [=](sycl::nd_item<3 > item_ct1) {
106
- concat_f32_dim1 (x, y, dst, ne0, ne01, item_ct1);
107
- });
108
- break ;
98
+ sycl_parallel_for (stream,
99
+ sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
100
+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
101
+ [=](sycl::nd_item<3 > item_ct1) { concat_f32_dim1 (x, y, dst, ne0, ne01, item_ct1); });
102
+ break ;
109
103
// dim >=2 will be dispatched to the default path
110
104
default :
111
- sycl_parallel_for (stream,
112
- sycl::nd_range<3 >(gridDim *
113
- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
114
- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
115
- [=](sycl::nd_item<3 > item_ct1) {
116
- concat_f32_dim2 (x, y, dst, ne0, ne02, item_ct1);
117
- });
118
- break ;
105
+ sycl_parallel_for (stream,
106
+ sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
107
+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
108
+ [=](sycl::nd_item<3 > item_ct1) { concat_f32_dim2 (x, y, dst, ne0, ne02, item_ct1); });
109
+ break ;
119
110
}
120
111
}
121
112
@@ -129,33 +120,29 @@ static void concat_f32_sycl_non_cont(
129
120
int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
130
121
uint64_t nb3, int32_t dim) {
131
122
sycl::range<3 > gridDim (ne3, ne2, ne1);
132
- sycl_parallel_for (stream,
133
- sycl::nd_range<3 >(gridDim, sycl::range<3 >(1 , 1 , 1 )),
134
- [=](sycl::nd_item<3 > item_ct1) {
135
- int64_t i3 = item_ct1.get_group (0 );
136
- int64_t i2 = item_ct1.get_group (1 );
137
- int64_t i1 = item_ct1.get_group (2 );
123
+ sycl_parallel_for (stream, sycl::nd_range<3 >(gridDim, sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
124
+ int64_t i3 = item_ct1.get_group (0 );
125
+ int64_t i2 = item_ct1.get_group (1 );
126
+ int64_t i1 = item_ct1.get_group (2 );
138
127
139
- int64_t o[4 ] = {0 , 0 , 0 , 0 };
140
- o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
128
+ int64_t o[4 ] = { 0 , 0 , 0 , 0 };
129
+ o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
141
130
142
- const float *x;
131
+ const float * x;
143
132
144
- for (int i0 = item_ct1.get_local_id (2 ); i0 < ne0;
145
- i0 += item_ct1.get_local_range (2 )) {
133
+ for (int i0 = item_ct1.get_local_id (2 ); i0 < ne0; i0 += item_ct1.get_local_range (2 )) {
146
134
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
147
- x = (const float *)(src0 + (i3)*nb03 + (i2)*nb02 + (i1)*nb01 +
148
- (i0)*nb00);
135
+ x = (const float *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00);
149
136
} else {
150
- x = (const float *)(src1 + (i3 - o[3 ]) * nb13 + (i2 - o[2 ]) * nb12 +
151
- (i1 - o[ 1 ]) * nb11 + (i0 - o[0 ]) * nb10);
137
+ x = (const float *) (src1 + (i3 - o[3 ]) * nb13 + (i2 - o[2 ]) * nb12 + (i1 - o[ 1 ]) * nb11 +
138
+ (i0 - o[0 ]) * nb10);
152
139
}
153
140
154
141
float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
155
142
156
143
*y = *x;
157
- }
158
- });
144
+ }
145
+ });
159
146
}
160
147
161
148
void ggml_sycl_op_concat (ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
0 commit comments