@@ -47,6 +47,12 @@ using namespace dpctl::tensor::offset_utils;
 template <typename srcT, typename dstT, typename IndexerT>
 class copy_cast_generic_kernel;
 
+template <typename srcT,
+          typename dstT,
+          unsigned int vec_sz,
+          unsigned int n_vecs>
+class copy_cast_contig_kernel;
+
 template <typename srcT, typename dstT, typename IndexerT>
 class copy_cast_from_host_kernel;
 
@@ -191,6 +197,166 @@ template <typename fnT, typename D, typename S> struct CopyAndCastGenericFactory
 }
 };
 
+// Specialization of copy_and_cast for contiguous arrays of different data types
+
+template <typename srcT,
+          typename dstT,
+          typename CastFnT,
+          int vec_sz = 4,
+          int n_vecs = 2>
+class ContigCopyFunctor
+{
+private:
+    const size_t nelems;
+    const srcT *src_p = nullptr;
+    dstT *dst_p = nullptr;
+
+public:
+    ContigCopyFunctor(const size_t nelems_, const srcT *src_p_, dstT *dst_p_)
+        : nelems(nelems_), src_p(src_p_), dst_p(dst_p_)
+    {
+    }
+
+    void operator()(sycl::nd_item<1> ndit) const
+    {
+        CastFnT fn{};
+
+        using dpctl::tensor::type_utils::is_complex;
+        if constexpr (is_complex<srcT>::value || is_complex<dstT>::value) {
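+            // sycl::vec does not support complex element types, so each
+            // work-item processes its share of up to n_vecs * vec_sz elements
+            // with plain scalar loads and stores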
+            std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
+            size_t base = ndit.get_global_linear_id();
+
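+            // remap the flat id so that each sub-group covers a contiguous
+            // block of sgSize * n_vecs * vec_sz elements, with this work-item
+            // starting at its lane position within the block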
+            base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
+            for (size_t offset = base;
+                 offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
+                 offset += sgSize)
+            {
+                dst_p[offset] = fn(src_p[offset]);
+            }
+        }
+        else {
+            auto sg = ndit.get_sub_group();
+            std::uint8_t sgSize = sg.get_local_range()[0];
+            std::uint8_t max_sgSize = sg.get_max_local_range()[0];
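+            // first element of the contiguous block of
+            // n_vecs * vec_sz * max_sgSize elements assigned to this sub-group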
+            size_t base = n_vecs * vec_sz *
+                          (ndit.get_group(0) * ndit.get_local_range(0) +
+                           sg.get_group_id()[0] * max_sgSize);
+
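+            // the whole block fits within the array and the sub-group is
+            // full-sized: use vectorized sub-group loads and stores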
+            if (base + n_vecs * vec_sz * sgSize < nelems &&
+                sgSize == max_sgSize) {
+                using src_ptrT =
+                    sycl::multi_ptr<const srcT,
+                                    sycl::access::address_space::global_space>;
+                using dst_ptrT =
+                    sycl::multi_ptr<dstT,
+                                    sycl::access::address_space::global_space>;
+                sycl::vec<srcT, vec_sz> src_vec;
+                sycl::vec<dstT, vec_sz> dst_vec;
+
+#pragma unroll
+                for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
+                    src_vec =
+                        sg.load<vec_sz>(src_ptrT(&src_p[base + it * sgSize]));
+#pragma unroll
+                    for (std::uint8_t k = 0; k < vec_sz; k++) {
+                        dst_vec[k] = fn(src_vec[k]);
+                    }
+                    sg.store<vec_sz>(dst_ptrT(&dst_p[base + it * sgSize]),
+                                     dst_vec);
+                }
+            }
+            else {
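+                // array tail or partial sub-group: fall back to a scalar,
+                // sub-group-strided loop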
+                for (size_t k = base + sg.get_local_id()[0]; k < nelems;
+                     k += sgSize) {
+                    dst_p[k] = fn(src_p[k]);
+                }
+            }
+        }
+    }
+};
+
+/*!
+ * @brief Function pointer type for contiguous array cast and copy function.
+ */
+typedef sycl::event (*copy_and_cast_contig_fn_ptr_t)(
+    sycl::queue,
+    size_t,
+    const char *,
+    char *,
+    const std::vector<sycl::event> &);
+
+/*!
+ * @brief Function to copy `nelems` elements from contiguous `src` usm_ndarray
+ to contiguous `dst` usm_ndarray while casting from `srcTy` to `dstTy`.
+
+ Both arrays have the same number of elements `nelems`.
+ `src_cp` and `dst_cp` represent char pointers to the start of respective
+ arrays. Kernel is submitted to sycl queue `q` with events `depends` as
+ dependencies.
+
+ @param q Sycl queue to which the kernel is submitted.
+ @param nelems Number of elements to cast and copy.
+ @param src_cp Kernel accessible USM pointer for the source array
+ @param dst_cp Kernel accessible USM pointer for the destination array
+ @param depends List of events to wait for before starting computations, if
+ any.
+
+ @return Event to wait on to ensure that computation completes.
+ @ingroup CopyAndCastKernels
+ */
+template <typename dstTy, typename srcTy>
+sycl::event copy_and_cast_contig_impl(sycl::queue q,
+                                      size_t nelems,
+                                      const char *src_cp,
+                                      char *dst_cp,
+                                      const std::vector<sycl::event> &depends)
+{
+    dpctl::tensor::type_utils::validate_type_for_device<dstTy>(q);
+    dpctl::tensor::type_utils::validate_type_for_device<srcTy>(q);
+
+    sycl::event copy_and_cast_ev = q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        const srcTy *src_tp = reinterpret_cast<const srcTy *>(src_cp);
+        dstTy *dst_tp = reinterpret_cast<dstTy *>(dst_cp);
+
+        size_t lws = 64;
+        constexpr unsigned int vec_sz = 4;
+        constexpr unsigned int n_vecs = 2;
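+        // each work-item handles n_vecs * vec_sz elements, so the number of
+        // work-groups is ceil(nelems / (lws * n_vecs * vec_sz))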
+        const size_t n_groups =
+            ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
+        const auto gws_range = sycl::range<1>(n_groups * lws);
+        const auto lws_range = sycl::range<1>(lws);
+
+        cgh.parallel_for<copy_cast_contig_kernel<srcTy, dstTy, n_vecs, vec_sz>>(
+            sycl::nd_range<1>(gws_range, lws_range),
+            ContigCopyFunctor<srcTy, dstTy, Caster<srcTy, dstTy>, vec_sz,
+                              n_vecs>(nelems, src_tp, dst_tp));
+    });
+
+    return copy_and_cast_ev;
+}
+
+/*!
+ * @brief Factory to get specialized function pointer for casting and copying
+ * contiguous arrays of different types.
+ * @ingroup CopyAndCastKernels
+ */
+template <typename fnT, typename D, typename S> struct CopyAndCastContigFactory
+{
+    fnT get()
+    {
+        if constexpr (std::is_same_v<D, S>) {
+            fnT fn = nullptr;
+            return fn;
+        }
+        else {
+            fnT f = copy_and_cast_contig_impl<D, S>;
+            return f;
+        }
+    }
+};
+
 // Specialization of copy_and_cast for 1D arrays
 
 /*!
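For reference, a minimal usage sketch of the pieces added above, assuming the header containing them is included; the wrapper function name and the float-to-double type pair are illustrative only, not part of the change:

```cpp
#include <cstddef>
#include <vector>
#include <sycl/sycl.hpp>
// assumes the header declaring CopyAndCastContigFactory,
// copy_and_cast_contig_fn_ptr_t and copy_and_cast_contig_impl is included

// Hypothetical wrapper: obtain and invoke the contiguous copy-and-cast
// function for casting float (source) to double (destination).
sycl::event copy_float_to_double_contig(sycl::queue q,
                                        size_t nelems,
                                        const char *src,
                                        char *dst,
                                        const std::vector<sycl::event> &depends)
{
    // template parameters: fnT = function pointer type, D = destination type,
    // S = source type
    CopyAndCastContigFactory<copy_and_cast_contig_fn_ptr_t, double, float>
        factory;
    copy_and_cast_contig_fn_ptr_t fn = factory.get(); // non-null since D != S
    return fn(q, nelems, src, dst, depends);
}
```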