@@ -46,51 +46,46 @@ using namespace dpctl::tensor::offset_utils;
46
46
47
47
template <typename srcT, typename dstT, typename IndexerT>
48
48
class copy_cast_generic_kernel ;
49
+
49
50
template <typename srcT, typename dstT, typename IndexerT>
50
51
class copy_cast_from_host_kernel ;
51
- // template <typename srcT, typename dstT, typename IndexerT>
52
- // class copy_cast_spec_kernel;
52
+
53
53
template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
54
54
class copy_for_reshape_generic_kernel ;
55
55
56
- template <typename srcT , typename dstT > class Caster
56
+ template <typename srcTy , typename dstTy > class Caster
57
57
{
58
58
public:
59
59
Caster () = default ;
60
- void operator ()(const char *src,
61
- std::ptrdiff_t src_offset,
62
- char *dst,
63
- std::ptrdiff_t dst_offset) const
60
+ dstTy operator ()(const srcTy &src) const
64
61
{
65
62
using dpctl::tensor::type_utils::convert_impl;
66
-
67
- const srcT *src_ = reinterpret_cast <const srcT *>(src) + src_offset;
68
- dstT *dst_ = reinterpret_cast <dstT *>(dst) + dst_offset;
69
- *dst_ = convert_impl<dstT, srcT>(*src_);
63
+ return convert_impl<dstTy, srcTy>(src);
70
64
}
71
65
};
72
66
73
- template <typename CastFnT, typename IndexerT> class GenericCopyFunctor
67
+ template <typename srcT, typename dstT, typename CastFnT, typename IndexerT>
68
+ class GenericCopyFunctor
74
69
{
75
70
private:
76
- const char *src_ = nullptr ;
77
- char *dst_ = nullptr ;
71
+ const srcT *src_ = nullptr ;
72
+ dstT *dst_ = nullptr ;
78
73
IndexerT indexer_;
79
74
80
75
public:
81
- GenericCopyFunctor (const char *src_cp, char *dst_cp , IndexerT indexer)
82
- : src_(src_cp ), dst_(dst_cp ), indexer_(indexer)
76
+ GenericCopyFunctor (const srcT *src_p, dstT *dst_p , IndexerT indexer)
77
+ : src_(src_p ), dst_(dst_p ), indexer_(indexer)
83
78
{
84
79
}
85
80
86
81
void operator ()(sycl::id<1 > wiid) const
87
82
{
88
- auto offsets = indexer_ (static_cast <py::ssize_t >(wiid.get (0 )));
89
- py::ssize_t src_offset = offsets.get_first_offset ();
90
- py::ssize_t dst_offset = offsets.get_second_offset ();
83
+ const auto & offsets = indexer_ (static_cast <py::ssize_t >(wiid.get (0 )));
84
+ const py::ssize_t & src_offset = offsets.get_first_offset ();
85
+ const py::ssize_t & dst_offset = offsets.get_second_offset ();
91
86
92
87
CastFnT fn{};
93
- fn (src_, src_offset, dst_, dst_offset );
88
+ dst_[dst_offset] = fn (src_[ src_offset] );
94
89
}
95
90
};
96
91
@@ -168,12 +163,15 @@ copy_and_cast_generic_impl(sycl::queue q,
168
163
169
164
TwoOffsets_StridedIndexer indexer{nd, src_offset, dst_offset,
170
165
shape_and_strides};
166
+ const srcTy *src_tp = reinterpret_cast <const srcTy *>(src_p);
167
+ dstTy *dst_tp = reinterpret_cast <dstTy *>(dst_p);
171
168
172
169
cgh.parallel_for <class copy_cast_generic_kernel <
173
170
srcTy, dstTy, TwoOffsets_StridedIndexer>>(
174
171
sycl::range<1 >(nelems),
175
- GenericCopyFunctor<Caster<srcTy, dstTy>, TwoOffsets_StridedIndexer>(
176
- src_p, dst_p, indexer));
172
+ GenericCopyFunctor<srcTy, dstTy, Caster<srcTy, dstTy>,
173
+ TwoOffsets_StridedIndexer>(src_tp, dst_tp,
174
+ indexer));
177
175
});
178
176
179
177
return copy_and_cast_ev;
@@ -276,13 +274,15 @@ copy_and_cast_nd_specialized_impl(sycl::queue q,
276
274
using IndexerT = TwoOffsets_FixedDimStridedIndexer<nd>;
277
275
IndexerT indexer{shape, src_strides, dst_strides, src_offset,
278
276
dst_offset};
277
+ const srcTy *src_tp = reinterpret_cast <const srcTy *>(src_p);
278
+ dstTy *dst_tp = reinterpret_cast <dstTy *>(dst_p);
279
279
280
280
cgh.depends_on (depends);
281
281
cgh.parallel_for <
282
282
class copy_cast_generic_kernel <srcTy, dstTy, IndexerT>>(
283
283
sycl::range<1 >(nelems),
284
- GenericCopyFunctor<Caster<srcTy, dstTy>, IndexerT>(src_p, dst_p,
285
- indexer));
284
+ GenericCopyFunctor<srcTy, dstTy, Caster<srcTy, dstTy>, IndexerT>(
285
+ src_tp, dst_tp, indexer));
286
286
});
287
287
288
288
return copy_and_cast_ev;
@@ -318,46 +318,33 @@ template <typename fnT, typename D, typename S> struct CopyAndCast2DFactory
318
318
319
319
// ====================== Copying from host to USM
320
320
321
- template <typename srcT, typename dstT, typename AccessorT>
322
- class CasterForAccessor
323
- {
324
- public:
325
- CasterForAccessor () = default ;
326
- void operator ()(AccessorT src,
327
- std::ptrdiff_t src_offset,
328
- char *dst,
329
- std::ptrdiff_t dst_offset) const
330
- {
331
- using dpctl::tensor::type_utils::convert_impl;
332
-
333
- dstT *dst_ = reinterpret_cast <dstT *>(dst) + dst_offset;
334
- *dst_ = convert_impl<dstT, srcT>(src[src_offset]);
335
- }
336
- };
337
-
338
- template <typename CastFnT, typename AccessorT, typename IndexerT>
321
+ template <typename AccessorT,
322
+ typename dstTy,
323
+ typename CastFnT,
324
+ typename IndexerT>
339
325
class GenericCopyFromHostFunctor
340
326
{
341
327
private:
342
328
AccessorT src_acc_;
343
- char *dst_ = nullptr ;
329
+ dstTy *dst_ = nullptr ;
344
330
IndexerT indexer_;
345
331
346
332
public:
347
333
GenericCopyFromHostFunctor (AccessorT src_acc,
348
- char *dst_cp ,
334
+ dstTy *dst_p ,
349
335
IndexerT indexer)
350
- : src_acc_(src_acc), dst_(dst_cp ), indexer_(indexer)
336
+ : src_acc_(src_acc), dst_(dst_p ), indexer_(indexer)
351
337
{
352
338
}
353
339
354
340
void operator ()(sycl::id<1 > wiid) const
355
341
{
356
- auto offsets = indexer_ (static_cast <py::ssize_t >(wiid.get (0 )));
357
- py::ssize_t src_offset = offsets.get_first_offset ();
358
- py::ssize_t dst_offset = offsets.get_second_offset ();
342
+ const auto &offsets = indexer_ (static_cast <py::ssize_t >(wiid.get (0 )));
343
+ const py::ssize_t &src_offset = offsets.get_first_offset ();
344
+ const py::ssize_t &dst_offset = offsets.get_second_offset ();
345
+
359
346
CastFnT fn{};
360
- fn (src_acc_, src_offset, dst_, dst_offset );
347
+ dst_[dst_offset] = fn (src_acc_[ src_offset] );
361
348
}
362
349
};
363
350
@@ -447,13 +434,15 @@ void copy_and_cast_from_host_impl(
447
434
nd, src_offset - src_min_nelem_offset, dst_offset,
448
435
const_cast <const py::ssize_t *>(shape_and_strides)};
449
436
437
+ dstTy *dst_tp = reinterpret_cast <dstTy *>(dst_p);
438
+
450
439
cgh.parallel_for <copy_cast_from_host_kernel<srcTy, dstTy,
451
440
TwoOffsets_StridedIndexer>>(
452
441
sycl::range<1 >(nelems),
453
- GenericCopyFromHostFunctor<
454
- CasterForAccessor <srcTy, dstTy, decltype (npy_acc) >,
455
- decltype (npy_acc), TwoOffsets_StridedIndexer>(npy_acc, dst_p,
456
- indexer));
442
+ GenericCopyFromHostFunctor<decltype (npy_acc), dstTy,
443
+ Caster <srcTy, dstTy>,
444
+ TwoOffsets_StridedIndexer>(
445
+ npy_acc, dst_tp, indexer));
457
446
});
458
447
459
448
// perform explicit synchronization. Implicit synchronization would be
0 commit comments