
Commit 02d1f94

Merge pull request #1165 from IntelPython/simplify-copy-and-cast-kernels
Simplify copy and cast kernels
2 parents: ad02f84 + 6933675

File tree: 3 files changed (+245, -98 lines)

dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp

Lines changed: 202 additions & 53 deletions
@@ -46,51 +46,52 @@ using namespace dpctl::tensor::offset_utils;
 
 template <typename srcT, typename dstT, typename IndexerT>
 class copy_cast_generic_kernel;
+
+template <typename srcT,
+          typename dstT,
+          unsigned int vec_sz,
+          unsigned int n_vecs>
+class copy_cast_contig_kernel;
+
 template <typename srcT, typename dstT, typename IndexerT>
 class copy_cast_from_host_kernel;
-// template <typename srcT, typename dstT, typename IndexerT>
-// class copy_cast_spec_kernel;
+
 template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
 class copy_for_reshape_generic_kernel;
 
-template <typename srcT, typename dstT> class Caster
+template <typename srcTy, typename dstTy> class Caster
 {
 public:
     Caster() = default;
-    void operator()(const char *src,
-                    std::ptrdiff_t src_offset,
-                    char *dst,
-                    std::ptrdiff_t dst_offset) const
+    dstTy operator()(const srcTy &src) const
     {
         using dpctl::tensor::type_utils::convert_impl;
-
-        const srcT *src_ = reinterpret_cast<const srcT *>(src) + src_offset;
-        dstT *dst_ = reinterpret_cast<dstT *>(dst) + dst_offset;
-        *dst_ = convert_impl<dstT, srcT>(*src_);
+        return convert_impl<dstTy, srcTy>(src);
     }
 };
 
-template <typename CastFnT, typename IndexerT> class GenericCopyFunctor
+template <typename srcT, typename dstT, typename CastFnT, typename IndexerT>
+class GenericCopyFunctor
 {
 private:
-    const char *src_ = nullptr;
-    char *dst_ = nullptr;
+    const srcT *src_ = nullptr;
+    dstT *dst_ = nullptr;
     IndexerT indexer_;
 
 public:
-    GenericCopyFunctor(const char *src_cp, char *dst_cp, IndexerT indexer)
-        : src_(src_cp), dst_(dst_cp), indexer_(indexer)
+    GenericCopyFunctor(const srcT *src_p, dstT *dst_p, IndexerT indexer)
+        : src_(src_p), dst_(dst_p), indexer_(indexer)
     {
     }
 
     void operator()(sycl::id<1> wiid) const
     {
-        auto offsets = indexer_(static_cast<py::ssize_t>(wiid.get(0)));
-        py::ssize_t src_offset = offsets.get_first_offset();
-        py::ssize_t dst_offset = offsets.get_second_offset();
+        const auto &offsets = indexer_(static_cast<py::ssize_t>(wiid.get(0)));
+        const py::ssize_t &src_offset = offsets.get_first_offset();
+        const py::ssize_t &dst_offset = offsets.get_second_offset();
 
         CastFnT fn{};
-        fn(src_, src_offset, dst_, dst_offset);
+        dst_[dst_offset] = fn(src_[src_offset]);
     }
 };
 
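This first hunk replaces the old byte-oriented Caster, which reinterpreted char pointers and applied offsets on every call, with a pure value conversion, and gives GenericCopyFunctor typed pointers so each work item reduces to dst_[dst_offset] = fn(src_[src_offset]). As an illustration of that call pattern only (not code from this commit; caster_like and copy_with_cast are hypothetical names, and static_cast stands in for convert_impl), a minimal host-side sketch:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Analogue of the simplified Caster: a pure value conversion.
    template <typename srcTy, typename dstTy> struct caster_like
    {
        dstTy operator()(const srcTy &v) const { return static_cast<dstTy>(v); }
    };

    // Analogue of GenericCopyFunctor::operator(): typed pointers in, one element
    // converted and stored per "work item" i.
    template <typename srcTy, typename dstTy>
    void copy_with_cast(const srcTy *src, dstTy *dst, std::size_t n)
    {
        caster_like<srcTy, dstTy> fn{};
        for (std::size_t i = 0; i < n; ++i) {
            dst[i] = fn(src[i]);
        }
    }

    int main()
    {
        std::vector<float> src{1.5f, 2.5f, 3.5f};
        std::vector<int> dst(src.size());
        copy_with_cast(src.data(), dst.data(), src.size());
        std::cout << dst[0] << " " << dst[1] << " " << dst[2] << "\n"; // 1 2 3
    }

The same shape carries over to the SYCL functors: the reinterpret_cast from char* to a typed pointer now happens once, at kernel submission (next hunk), instead of once per element.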
@@ -168,12 +169,15 @@ copy_and_cast_generic_impl(sycl::queue q,
 
         TwoOffsets_StridedIndexer indexer{nd, src_offset, dst_offset,
                                           shape_and_strides};
+        const srcTy *src_tp = reinterpret_cast<const srcTy *>(src_p);
+        dstTy *dst_tp = reinterpret_cast<dstTy *>(dst_p);
 
         cgh.parallel_for<class copy_cast_generic_kernel<
             srcTy, dstTy, TwoOffsets_StridedIndexer>>(
             sycl::range<1>(nelems),
-            GenericCopyFunctor<Caster<srcTy, dstTy>, TwoOffsets_StridedIndexer>(
-                src_p, dst_p, indexer));
+            GenericCopyFunctor<srcTy, dstTy, Caster<srcTy, dstTy>,
+                               TwoOffsets_StridedIndexer>(src_tp, dst_tp,
+                                                          indexer));
     });
 
     return copy_and_cast_ev;
@@ -193,6 +197,160 @@ template <typename fnT, typename D, typename S> struct CopyAndCastGenericFactory
     }
 };
 
+// Specialization of copy_and_cast for contiguous arrays
+
+template <typename srcT,
+          typename dstT,
+          typename CastFnT,
+          int vec_sz = 4,
+          int n_vecs = 2>
+class ContigCopyFunctor
+{
+private:
+    const size_t nelems;
+    const srcT *src_p = nullptr;
+    dstT *dst_p = nullptr;
+
+public:
+    ContigCopyFunctor(const size_t nelems_, const srcT *src_p_, dstT *dst_p_)
+        : nelems(nelems_), src_p(src_p_), dst_p(dst_p_)
+    {
+    }
+
+    void operator()(sycl::nd_item<1> ndit) const
+    {
+        CastFnT fn{};
+
+        using dpctl::tensor::type_utils::is_complex;
+        if constexpr (is_complex<srcT>::value || is_complex<dstT>::value) {
+            std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
+            size_t base = ndit.get_global_linear_id();
+
+            base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
+            for (size_t offset = base;
+                 offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
+                 offset += sgSize)
+            {
+                dst_p[offset] = fn(src_p[offset]);
+            }
+        }
+        else {
+            auto sg = ndit.get_sub_group();
+            std::uint8_t sgSize = sg.get_local_range()[0];
+            std::uint8_t max_sgSize = sg.get_max_local_range()[0];
+            size_t base = n_vecs * vec_sz *
+                          (ndit.get_group(0) * ndit.get_local_range(0) +
+                           sg.get_group_id()[0] * max_sgSize);
+
+            if (base + n_vecs * vec_sz * sgSize < nelems &&
+                sgSize == max_sgSize) {
+                using src_ptrT =
+                    sycl::multi_ptr<const srcT,
+                                    sycl::access::address_space::global_space>;
+                using dst_ptrT =
+                    sycl::multi_ptr<dstT,
+                                    sycl::access::address_space::global_space>;
+                sycl::vec<srcT, vec_sz> src_vec;
+                sycl::vec<dstT, vec_sz> dst_vec;
+
+#pragma unroll
+                for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
+                    src_vec =
+                        sg.load<vec_sz>(src_ptrT(&src_p[base + it * sgSize]));
+#pragma unroll
+                    for (std::uint8_t k = 0; k < vec_sz; k++) {
+                        dst_vec[k] = fn(src_vec[k]);
+                    }
+                    sg.store<vec_sz>(dst_ptrT(&dst_p[base + it * sgSize]),
+                                     dst_vec);
+                }
+            }
+            else {
+                for (size_t k = base + sg.get_local_id()[0]; k < nelems;
+                     k += sgSize) {
+                    dst_p[k] = fn(src_p[k]);
+                }
+            }
+        }
+    }
+};
+
+/*!
+ * @brief Function pointer type for contiguous array cast and copy function.
+ */
+typedef sycl::event (*copy_and_cast_contig_fn_ptr_t)(
+    sycl::queue,
+    size_t,
+    const char *,
+    char *,
+    const std::vector<sycl::event> &);
+
+/*!
+ * @brief Function to copy `nelems` elements from contiguous `src` usm_ndarray
+ to contiguous `dst` usm_ndarray while casting from `srcTy` to `dstTy`.
+
+ Both arrays have the same number of elements `nelems`.
+ `src_cp` and `dst_cp` represent char pointers to the start of respective
+ arrays. Kernel is submitted to sycl queue `q` with events `depends` as
+ dependencies.
+
+ @param q Sycl queue to which the kernel is submitted.
+ @param nelems Number of elements to cast and copy.
+ @param src_p Kernel accessible USM pointer for the source array
+ @param dst_p Kernel accessible USM pointer for the destination array
+ @param depends List of events to wait for before starting computations, if
+ any.
+
+ @return Event to wait on to ensure that computation completes.
+ @ingroup CopyAndCastKernels
+ */
+template <typename dstTy, typename srcTy>
+sycl::event copy_and_cast_contig_impl(sycl::queue q,
+                                      size_t nelems,
+                                      const char *src_cp,
+                                      char *dst_cp,
+                                      const std::vector<sycl::event> &depends)
+{
+    dpctl::tensor::type_utils::validate_type_for_device<dstTy>(q);
+    dpctl::tensor::type_utils::validate_type_for_device<srcTy>(q);
+
+    sycl::event copy_and_cast_ev = q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        const srcTy *src_tp = reinterpret_cast<const srcTy *>(src_cp);
+        dstTy *dst_tp = reinterpret_cast<dstTy *>(dst_cp);
+
+        size_t lws = 64;
+        constexpr unsigned int vec_sz = 4;
+        constexpr unsigned int n_vecs = 2;
+        const size_t n_groups =
+            ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
+        const auto gws_range = sycl::range<1>(n_groups * lws);
+        const auto lws_range = sycl::range<1>(lws);
+
+        cgh.parallel_for<copy_cast_contig_kernel<srcTy, dstTy, n_vecs, vec_sz>>(
+            sycl::nd_range<1>(gws_range, lws_range),
+            ContigCopyFunctor<srcTy, dstTy, Caster<srcTy, dstTy>, vec_sz,
+                              n_vecs>(nelems, src_tp, dst_tp));
+    });
+
+    return copy_and_cast_ev;
+}
+
+/*!
+ * @brief Factory to get specialized function pointer for casting and copying
+ * contiguous arrays.
+ * @ingroup CopyAndCastKernels
+ */
+template <typename fnT, typename D, typename S> struct CopyAndCastContigFactory
+{
+    fnT get()
+    {
+        fnT f = copy_and_cast_contig_impl<D, S>;
+        return f;
+    }
+};
+
 // Specialization of copy_and_cast for 1D arrays
 
 /*!
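The new contiguous implementation sizes its nd_range so that each work item is responsible for n_vecs * vec_sz = 8 elements and each work-group of lws = 64 items covers 512; n_groups is the ceiling division of nelems by that tile. A small host-side sketch of the same arithmetic, evaluated for an assumed element count (illustration only, not part of the commit):

    #include <cstddef>
    #include <iostream>

    int main()
    {
        const std::size_t nelems = 100000; // hypothetical array size
        const std::size_t lws = 64;        // work-group size used by the kernel
        const std::size_t vec_sz = 4;      // elements per sub-group vector load/store
        const std::size_t n_vecs = 2;      // vector transfers per work item

        // Each work-group covers lws * n_vecs * vec_sz = 512 elements, so the
        // group count is a ceiling division of nelems by that tile size.
        const std::size_t elems_per_group = lws * n_vecs * vec_sz;
        const std::size_t n_groups =
            (nelems + elems_per_group - 1) / elems_per_group;

        std::cout << "groups: " << n_groups              // 196
                  << ", global size: " << n_groups * lws // 12544 work items
                  << "\n";
    }

Inside a full tile, ContigCopyFunctor uses sub-group vectorized loads and stores (n_vecs sycl::vec transfers per work item); the tail work-group and complex-valued types fall back to the scalar per-element loop.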
@@ -276,13 +434,15 @@ copy_and_cast_nd_specialized_impl(sycl::queue q,
         using IndexerT = TwoOffsets_FixedDimStridedIndexer<nd>;
         IndexerT indexer{shape, src_strides, dst_strides, src_offset,
                          dst_offset};
+        const srcTy *src_tp = reinterpret_cast<const srcTy *>(src_p);
+        dstTy *dst_tp = reinterpret_cast<dstTy *>(dst_p);
 
         cgh.depends_on(depends);
         cgh.parallel_for<
             class copy_cast_generic_kernel<srcTy, dstTy, IndexerT>>(
             sycl::range<1>(nelems),
-            GenericCopyFunctor<Caster<srcTy, dstTy>, IndexerT>(src_p, dst_p,
-                                                               indexer));
+            GenericCopyFunctor<srcTy, dstTy, Caster<srcTy, dstTy>, IndexerT>(
+                src_tp, dst_tp, indexer));
     });
 
     return copy_and_cast_ev;
@@ -318,46 +478,33 @@ template <typename fnT, typename D, typename S> struct CopyAndCast2DFactory
 
 // ====================== Copying from host to USM
 
-template <typename srcT, typename dstT, typename AccessorT>
-class CasterForAccessor
-{
-public:
-    CasterForAccessor() = default;
-    void operator()(AccessorT src,
-                    std::ptrdiff_t src_offset,
-                    char *dst,
-                    std::ptrdiff_t dst_offset) const
-    {
-        using dpctl::tensor::type_utils::convert_impl;
-
-        dstT *dst_ = reinterpret_cast<dstT *>(dst) + dst_offset;
-        *dst_ = convert_impl<dstT, srcT>(src[src_offset]);
-    }
-};
-
-template <typename CastFnT, typename AccessorT, typename IndexerT>
+template <typename AccessorT,
+          typename dstTy,
+          typename CastFnT,
+          typename IndexerT>
 class GenericCopyFromHostFunctor
 {
 private:
     AccessorT src_acc_;
-    char *dst_ = nullptr;
+    dstTy *dst_ = nullptr;
     IndexerT indexer_;
 
 public:
     GenericCopyFromHostFunctor(AccessorT src_acc,
-                               char *dst_cp,
+                               dstTy *dst_p,
                                IndexerT indexer)
-        : src_acc_(src_acc), dst_(dst_cp), indexer_(indexer)
+        : src_acc_(src_acc), dst_(dst_p), indexer_(indexer)
     {
     }
 
     void operator()(sycl::id<1> wiid) const
     {
-        auto offsets = indexer_(static_cast<py::ssize_t>(wiid.get(0)));
-        py::ssize_t src_offset = offsets.get_first_offset();
-        py::ssize_t dst_offset = offsets.get_second_offset();
+        const auto &offsets = indexer_(static_cast<py::ssize_t>(wiid.get(0)));
+        const py::ssize_t &src_offset = offsets.get_first_offset();
+        const py::ssize_t &dst_offset = offsets.get_second_offset();
+
         CastFnT fn{};
-        fn(src_acc_, src_offset, dst_, dst_offset);
+        dst_[dst_offset] = fn(src_acc_[src_offset]);
     }
 };
 
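Dropping CasterForAccessor works because the cast functor no longer performs any addressing: GenericCopyFromHostFunctor computes the offsets itself and only needs src_acc_[src_offset] to yield a srcTy value, so the same Caster serves both typed USM pointers and the host buffer accessor. A plain-C++ sketch of that idea (illustration only; copy_from and caster_like are hypothetical names):

    // Any source with element access (raw pointer, accessor-like object)
    // plugs into the same dst[i] = fn(src[i]) pattern.
    #include <cstddef>
    #include <iostream>
    #include <vector>

    template <typename srcTy, typename dstTy> struct caster_like
    {
        dstTy operator()(const srcTy &v) const { return static_cast<dstTy>(v); }
    };

    template <typename SourceT, typename srcTy, typename dstTy>
    void copy_from(const SourceT &src, dstTy *dst, std::size_t n)
    {
        caster_like<srcTy, dstTy> fn{};
        for (std::size_t i = 0; i < n; ++i) {
            dst[i] = fn(src[i]); // addressing lives here, not in the caster
        }
    }

    int main()
    {
        const double raw[3] = {0.5, 1.5, 2.5};
        std::vector<double> host_like{3.5, 4.5, 5.5}; // stands in for the host accessor
        float out[3];

        copy_from<const double *, double, float>(raw, out, 3);
        copy_from<std::vector<double>, double, float>(host_like, out, 3);
        std::cout << out[0] << " " << out[2] << "\n"; // 3.5 5.5
    }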
@@ -447,13 +594,15 @@ void copy_and_cast_from_host_impl(
             nd, src_offset - src_min_nelem_offset, dst_offset,
             const_cast<const py::ssize_t *>(shape_and_strides)};
 
+        dstTy *dst_tp = reinterpret_cast<dstTy *>(dst_p);
+
         cgh.parallel_for<copy_cast_from_host_kernel<srcTy, dstTy,
                                                     TwoOffsets_StridedIndexer>>(
             sycl::range<1>(nelems),
-            GenericCopyFromHostFunctor<
-                CasterForAccessor<srcTy, dstTy, decltype(npy_acc)>,
-                decltype(npy_acc), TwoOffsets_StridedIndexer>(npy_acc, dst_p,
-                                                              indexer));
+            GenericCopyFromHostFunctor<decltype(npy_acc), dstTy,
+                                       Caster<srcTy, dstTy>,
+                                       TwoOffsets_StridedIndexer>(
+                npy_acc, dst_tp, indexer));
     });
 
     // perform explicit synchronization. Implicit synchronization would be
