Skip to content

Commit bb39d3c

Browse files
Simplified copy-and-cast kernels
Functions should store typed pointers instead of typeless ones. The CastFnT effectively becomes a trivial call to convert_impl in its call operator. Also added a few data-movement optimizations.
1 parent 7bbfce1 commit bb39d3c

File tree

1 file changed

+42
-53
lines changed

1 file changed

+42
-53
lines changed

dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp

Lines changed: 42 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -46,51 +46,46 @@ using namespace dpctl::tensor::offset_utils;
4646

4747
template <typename srcT, typename dstT, typename IndexerT>
4848
class copy_cast_generic_kernel;
49+
4950
template <typename srcT, typename dstT, typename IndexerT>
5051
class copy_cast_from_host_kernel;
51-
// template <typename srcT, typename dstT, typename IndexerT>
52-
// class copy_cast_spec_kernel;
52+
5353
template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
5454
class copy_for_reshape_generic_kernel;
5555

56-
template <typename srcT, typename dstT> class Caster
56+
template <typename srcTy, typename dstTy> class Caster
5757
{
5858
public:
5959
Caster() = default;
60-
void operator()(const char *src,
61-
std::ptrdiff_t src_offset,
62-
char *dst,
63-
std::ptrdiff_t dst_offset) const
60+
dstTy operator()(const srcTy &src) const
6461
{
6562
using dpctl::tensor::type_utils::convert_impl;
66-
67-
const srcT *src_ = reinterpret_cast<const srcT *>(src) + src_offset;
68-
dstT *dst_ = reinterpret_cast<dstT *>(dst) + dst_offset;
69-
*dst_ = convert_impl<dstT, srcT>(*src_);
63+
return convert_impl<dstTy, srcTy>(src);
7064
}
7165
};
7266

73-
template <typename CastFnT, typename IndexerT> class GenericCopyFunctor
67+
template <typename srcT, typename dstT, typename CastFnT, typename IndexerT>
68+
class GenericCopyFunctor
7469
{
7570
private:
76-
const char *src_ = nullptr;
77-
char *dst_ = nullptr;
71+
const srcT *src_ = nullptr;
72+
dstT *dst_ = nullptr;
7873
IndexerT indexer_;
7974

8075
public:
81-
GenericCopyFunctor(const char *src_cp, char *dst_cp, IndexerT indexer)
82-
: src_(src_cp), dst_(dst_cp), indexer_(indexer)
76+
GenericCopyFunctor(const srcT *src_p, dstT *dst_p, IndexerT indexer)
77+
: src_(src_p), dst_(dst_p), indexer_(indexer)
8378
{
8479
}
8580

8681
void operator()(sycl::id<1> wiid) const
8782
{
88-
auto offsets = indexer_(static_cast<py::ssize_t>(wiid.get(0)));
89-
py::ssize_t src_offset = offsets.get_first_offset();
90-
py::ssize_t dst_offset = offsets.get_second_offset();
83+
const auto &offsets = indexer_(static_cast<py::ssize_t>(wiid.get(0)));
84+
const py::ssize_t &src_offset = offsets.get_first_offset();
85+
const py::ssize_t &dst_offset = offsets.get_second_offset();
9186

9287
CastFnT fn{};
93-
fn(src_, src_offset, dst_, dst_offset);
88+
dst_[dst_offset] = fn(src_[src_offset]);
9489
}
9590
};
9691

@@ -168,12 +163,15 @@ copy_and_cast_generic_impl(sycl::queue q,
168163

169164
TwoOffsets_StridedIndexer indexer{nd, src_offset, dst_offset,
170165
shape_and_strides};
166+
const srcTy *src_tp = reinterpret_cast<const srcTy *>(src_p);
167+
dstTy *dst_tp = reinterpret_cast<dstTy *>(dst_p);
171168

172169
cgh.parallel_for<class copy_cast_generic_kernel<
173170
srcTy, dstTy, TwoOffsets_StridedIndexer>>(
174171
sycl::range<1>(nelems),
175-
GenericCopyFunctor<Caster<srcTy, dstTy>, TwoOffsets_StridedIndexer>(
176-
src_p, dst_p, indexer));
172+
GenericCopyFunctor<srcTy, dstTy, Caster<srcTy, dstTy>,
173+
TwoOffsets_StridedIndexer>(src_tp, dst_tp,
174+
indexer));
177175
});
178176

179177
return copy_and_cast_ev;
@@ -276,13 +274,15 @@ copy_and_cast_nd_specialized_impl(sycl::queue q,
276274
using IndexerT = TwoOffsets_FixedDimStridedIndexer<nd>;
277275
IndexerT indexer{shape, src_strides, dst_strides, src_offset,
278276
dst_offset};
277+
const srcTy *src_tp = reinterpret_cast<const srcTy *>(src_p);
278+
dstTy *dst_tp = reinterpret_cast<dstTy *>(dst_p);
279279

280280
cgh.depends_on(depends);
281281
cgh.parallel_for<
282282
class copy_cast_generic_kernel<srcTy, dstTy, IndexerT>>(
283283
sycl::range<1>(nelems),
284-
GenericCopyFunctor<Caster<srcTy, dstTy>, IndexerT>(src_p, dst_p,
285-
indexer));
284+
GenericCopyFunctor<srcTy, dstTy, Caster<srcTy, dstTy>, IndexerT>(
285+
src_tp, dst_tp, indexer));
286286
});
287287

288288
return copy_and_cast_ev;
@@ -318,46 +318,33 @@ template <typename fnT, typename D, typename S> struct CopyAndCast2DFactory
318318

319319
// ====================== Copying from host to USM
320320

321-
template <typename srcT, typename dstT, typename AccessorT>
322-
class CasterForAccessor
323-
{
324-
public:
325-
CasterForAccessor() = default;
326-
void operator()(AccessorT src,
327-
std::ptrdiff_t src_offset,
328-
char *dst,
329-
std::ptrdiff_t dst_offset) const
330-
{
331-
using dpctl::tensor::type_utils::convert_impl;
332-
333-
dstT *dst_ = reinterpret_cast<dstT *>(dst) + dst_offset;
334-
*dst_ = convert_impl<dstT, srcT>(src[src_offset]);
335-
}
336-
};
337-
338-
template <typename CastFnT, typename AccessorT, typename IndexerT>
321+
template <typename AccessorT,
322+
typename dstTy,
323+
typename CastFnT,
324+
typename IndexerT>
339325
class GenericCopyFromHostFunctor
340326
{
341327
private:
342328
AccessorT src_acc_;
343-
char *dst_ = nullptr;
329+
dstTy *dst_ = nullptr;
344330
IndexerT indexer_;
345331

346332
public:
347333
GenericCopyFromHostFunctor(AccessorT src_acc,
348-
char *dst_cp,
334+
dstTy *dst_p,
349335
IndexerT indexer)
350-
: src_acc_(src_acc), dst_(dst_cp), indexer_(indexer)
336+
: src_acc_(src_acc), dst_(dst_p), indexer_(indexer)
351337
{
352338
}
353339

354340
void operator()(sycl::id<1> wiid) const
355341
{
356-
auto offsets = indexer_(static_cast<py::ssize_t>(wiid.get(0)));
357-
py::ssize_t src_offset = offsets.get_first_offset();
358-
py::ssize_t dst_offset = offsets.get_second_offset();
342+
const auto &offsets = indexer_(static_cast<py::ssize_t>(wiid.get(0)));
343+
const py::ssize_t &src_offset = offsets.get_first_offset();
344+
const py::ssize_t &dst_offset = offsets.get_second_offset();
345+
359346
CastFnT fn{};
360-
fn(src_acc_, src_offset, dst_, dst_offset);
347+
dst_[dst_offset] = fn(src_acc_[src_offset]);
361348
}
362349
};
363350

@@ -447,13 +434,15 @@ void copy_and_cast_from_host_impl(
447434
nd, src_offset - src_min_nelem_offset, dst_offset,
448435
const_cast<const py::ssize_t *>(shape_and_strides)};
449436

437+
dstTy *dst_tp = reinterpret_cast<dstTy *>(dst_p);
438+
450439
cgh.parallel_for<copy_cast_from_host_kernel<srcTy, dstTy,
451440
TwoOffsets_StridedIndexer>>(
452441
sycl::range<1>(nelems),
453-
GenericCopyFromHostFunctor<
454-
CasterForAccessor<srcTy, dstTy, decltype(npy_acc)>,
455-
decltype(npy_acc), TwoOffsets_StridedIndexer>(npy_acc, dst_p,
456-
indexer));
442+
GenericCopyFromHostFunctor<decltype(npy_acc), dstTy,
443+
Caster<srcTy, dstTy>,
444+
TwoOffsets_StridedIndexer>(
445+
npy_acc, dst_tp, indexer));
457446
});
458447

459448
// perform explicit synchronization. Implicit synchronization would be

0 commit comments

Comments (0)