Skip to content

Commit 2669110

Browse files
committed
Copy and cast kernel for contiguous data added
1 parent 6f7807a commit 2669110

File tree

2 files changed

+192
-11
lines changed

2 files changed

+192
-11
lines changed

dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ using namespace dpctl::tensor::offset_utils;
4747
template <typename srcT, typename dstT, typename IndexerT>
4848
class copy_cast_generic_kernel;
4949

50+
// Kernel name type for the contiguous copy-and-cast specialization.
// Parameterized on source/destination types and the vectorization
// factors so each instantiation gets a unique kernel name.
template <typename srcT,
          typename dstT,
          unsigned int vec_sz,
          unsigned int n_vecs>
class copy_cast_contig_kernel;
55+
5056
template <typename srcT, typename dstT, typename IndexerT>
5157
class copy_cast_from_host_kernel;
5258

@@ -191,6 +197,166 @@ template <typename fnT, typename D, typename S> struct CopyAndCastGenericFactory
191197
}
192198
};
193199

200+
// Specialization of copy_and_cast for contiguous arrays of different data types
201+
202+
template <typename srcT,
203+
typename dstT,
204+
typename CastFnT,
205+
int vec_sz = 4,
206+
int n_vecs = 2>
207+
class ContigCopyFunctor
208+
{
209+
private:
210+
const size_t nelems;
211+
const srcT *src_p = nullptr;
212+
dstT *dst_p = nullptr;
213+
214+
public:
215+
ContigCopyFunctor(const size_t nelems_, const srcT *src_p_, dstT *dst_p_)
216+
: nelems(nelems_), src_p(src_p_), dst_p(dst_p_)
217+
{
218+
}
219+
220+
void operator()(sycl::nd_item<1> ndit) const
221+
{
222+
CastFnT fn{};
223+
224+
using dpctl::tensor::type_utils::is_complex;
225+
if constexpr (is_complex<srcT>::value || is_complex<dstT>::value) {
226+
std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
227+
size_t base = ndit.get_global_linear_id();
228+
229+
base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
230+
for (size_t offset = base;
231+
offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
232+
offset += sgSize)
233+
{
234+
dst_p[offset] = fn(src_p[offset]);
235+
}
236+
}
237+
else {
238+
auto sg = ndit.get_sub_group();
239+
std::uint8_t sgSize = sg.get_local_range()[0];
240+
std::uint8_t max_sgSize = sg.get_max_local_range()[0];
241+
size_t base = n_vecs * vec_sz *
242+
(ndit.get_group(0) * ndit.get_local_range(0) +
243+
sg.get_group_id()[0] * max_sgSize);
244+
245+
if (base + n_vecs * vec_sz * sgSize < nelems &&
246+
sgSize == max_sgSize) {
247+
using src_ptrT =
248+
sycl::multi_ptr<const srcT,
249+
sycl::access::address_space::global_space>;
250+
using dst_ptrT =
251+
sycl::multi_ptr<dstT,
252+
sycl::access::address_space::global_space>;
253+
sycl::vec<srcT, vec_sz> src_vec;
254+
sycl::vec<dstT, vec_sz> dst_vec;
255+
256+
#pragma unroll
257+
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
258+
src_vec =
259+
sg.load<vec_sz>(src_ptrT(&src_p[base + it * sgSize]));
260+
#pragma unroll
261+
for (std::uint8_t k = 0; k < vec_sz; k++) {
262+
dst_vec[k] = fn(src_vec[k]);
263+
}
264+
sg.store<vec_sz>(dst_ptrT(&dst_p[base + it * sgSize]),
265+
dst_vec);
266+
}
267+
}
268+
else {
269+
for (size_t k = base + sg.get_local_id()[0]; k < nelems;
270+
k += sgSize) {
271+
dst_p[k] = fn(src_p[k]);
272+
}
273+
}
274+
}
275+
}
276+
};
277+
278+
/*!
279+
* @brief Function pointer type for contiguous array cast and copy function.
280+
*/
281+
typedef sycl::event (*copy_and_cast_contig_fn_ptr_t)(
282+
sycl::queue,
283+
size_t,
284+
const char *,
285+
char *,
286+
const std::vector<sycl::event> &);
287+
288+
/*!
289+
* @brief Function to copy `nelems` elements from contiguous `src` usm_ndarray
290+
to contiguous `dst` usm_ndarray while casting from `srcTy` to `dstTy`.
291+
292+
Both arrays have the same number of elements `nelems`.
293+
`src_cp` and `dst_cp` represent char pointers to the start of respective
294+
arrays. Kernel is submitted to sycl queue `q` with events `depends` as
295+
dependencies.
296+
297+
@param q Sycl queue to which the kernel is submitted.
298+
@param nelems Number of elements to cast and copy.
299+
@param src_p Kernel accessible USM pointer for the source array
300+
@param dst_p Kernel accessible USM pointer for the destination array
301+
@param depends List of events to wait for before starting computations, if
302+
any.
303+
304+
@return Event to wait on to ensure that computation completes.
305+
@ingroup CopyAndCastKernels
306+
*/
307+
template <typename dstTy, typename srcTy>
308+
sycl::event copy_and_cast_contig_impl(sycl::queue q,
309+
size_t nelems,
310+
const char *src_cp,
311+
char *dst_cp,
312+
const std::vector<sycl::event> &depends)
313+
{
314+
dpctl::tensor::type_utils::validate_type_for_device<dstTy>(q);
315+
dpctl::tensor::type_utils::validate_type_for_device<srcTy>(q);
316+
317+
sycl::event copy_and_cast_ev = q.submit([&](sycl::handler &cgh) {
318+
cgh.depends_on(depends);
319+
320+
const srcTy *src_tp = reinterpret_cast<const srcTy *>(src_cp);
321+
dstTy *dst_tp = reinterpret_cast<dstTy *>(dst_cp);
322+
323+
size_t lws = 64;
324+
constexpr unsigned int vec_sz = 4;
325+
constexpr unsigned int n_vecs = 2;
326+
const size_t n_groups =
327+
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
328+
const auto gws_range = sycl::range<1>(n_groups * lws);
329+
const auto lws_range = sycl::range<1>(lws);
330+
331+
cgh.parallel_for<copy_cast_contig_kernel<srcTy, dstTy, n_vecs, vec_sz>>(
332+
sycl::nd_range<1>(gws_range, lws_range),
333+
ContigCopyFunctor<srcTy, dstTy, Caster<srcTy, dstTy>, vec_sz,
334+
n_vecs>(nelems, src_tp, dst_tp));
335+
});
336+
337+
return copy_and_cast_ev;
338+
}
339+
340+
/*!
341+
* @brief Factory to get specialized function pointer for casting and copying
342+
* contiguous arrays of different types.
343+
* @ingroup CopyAndCastKernels
344+
*/
345+
template <typename fnT, typename D, typename S> struct CopyAndCastContigFactory
346+
{
347+
fnT get()
348+
{
349+
if constexpr (std::is_same_v<D, S>) {
350+
fnT fn = nullptr;
351+
return fn;
352+
}
353+
else {
354+
fnT f = copy_and_cast_contig_impl<D, S>;
355+
return f;
356+
}
357+
}
358+
};
359+
194360
// Specialization of copy_and_cast for 1D arrays
195361

196362
/*!

dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,15 @@ namespace py_internal
5252
namespace _ns = dpctl::tensor::detail;
5353

5454
using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_1d_fn_ptr_t;
55+
using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_contig_fn_ptr_t;
5556
using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_generic_fn_ptr_t;
5657

5758
static copy_and_cast_generic_fn_ptr_t
5859
copy_and_cast_generic_dispatch_table[_ns::num_types][_ns::num_types];
5960
static copy_and_cast_1d_fn_ptr_t
6061
copy_and_cast_1d_dispatch_table[_ns::num_types][_ns::num_types];
62+
static copy_and_cast_contig_fn_ptr_t
63+
copy_and_cast_contig_dispatch_table[_ns::num_types][_ns::num_types];
6164

6265
namespace py = pybind11;
6366

@@ -139,26 +142,32 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
139142
bool is_dst_f_contig = dst.is_f_contiguous();
140143

141144
// check for applicability of special cases:
142-
// (same type && (both C-contiguous || both F-contiguous)
145+
// (both C-contiguous || both F-contiguous)
143146
bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
144147
bool both_f_contig = (is_src_f_contig && is_dst_f_contig);
145148
if (both_c_contig || both_f_contig) {
149+
150+
sycl::event copy_ev;
146151
if (src_type_id == dst_type_id) {
147152

148153
int src_elem_size = src.get_elemsize();
149154

150-
sycl::event copy_ev =
151-
exec_q.memcpy(static_cast<void *>(dst_data),
152-
static_cast<const void *>(src_data),
153-
src_nelems * src_elem_size, depends);
154-
155-
// make sure src and dst are not GC-ed before copy_ev is complete
156-
return std::make_pair(
157-
keep_args_alive(exec_q, {src, dst}, {copy_ev}), copy_ev);
155+
copy_ev = exec_q.memcpy(static_cast<void *>(dst_data),
156+
static_cast<const void *>(src_data),
157+
src_nelems * src_elem_size, depends);
158158
}
159-
// With contract_iter2 in place, there is no need to write
160-
// dedicated kernels for casting between contiguous arrays
159+
else {
160+
auto fn =
161+
copy_and_cast_contig_dispatch_table[dst_type_id][src_type_id];
162+
copy_ev =
163+
contig_fn(exec_q, src_nelems, src_data, dst_data, depends);
164+
}
165+
// make sure src and dst are not GC-ed before copy_ev is complete
166+
return std::make_pair(keep_args_alive(exec_q, {src, dst}, {copy_ev}),
167+
copy_ev);
161168
}
169+
// With contract_iter2 in place, there is no need to write
170+
// dedicated kernels for casting between contiguous arrays
162171

163172
const py::ssize_t *src_strides = src.get_strides_raw();
164173
const py::ssize_t *dst_strides = dst.get_strides_raw();
@@ -259,6 +268,12 @@ void init_copy_and_cast_usm_to_usm_dispatch_tables(void)
259268
{
260269
using namespace dpctl::tensor::detail;
261270

271+
using dpctl::tensor::kernels::copy_and_cast::CopyAndCastContigFactory;
272+
DispatchTableBuilder<copy_and_cast_contig_fn_ptr_t,
273+
CopyAndCastContigFactory, num_types>
274+
dtb_contig;
275+
dtb_contig.populate_dispatch_table(copy_and_cast_contig_dispatch_table);
276+
262277
using dpctl::tensor::kernels::copy_and_cast::CopyAndCastGenericFactory;
263278
DispatchTableBuilder<copy_and_cast_generic_fn_ptr_t,
264279
CopyAndCastGenericFactory, num_types>

0 commit comments

Comments
 (0)