
Commit 0399b93

Factor out in-line use of submit to init padded vector into separate function

Doing so reduces the binary size of the elementwise operations extension: the inline submit compiled a separate padded-vector kernel at every call site, whereas the shared helper kernel is instantiated once per element type.

Before:

```
(dev_dpctl) opavlyk@mtl-world:~/repos/dpctl$ ls -l dpctl/tensor/_tensor_elementwise_impl.cpython-312-x86_64-linux-gnu.so
-rw-r--r-- 1 opavlyk opavlyk 38659896 Jan 19 20:58 dpctl/tensor/_tensor_elementwise_impl.cpython-312-x86_64-linux-gnu.so
```

After:

```
(dev_dpctl) opavlyk@mtl-world:~/repos/dpctl$ ls -l dpctl/tensor/_tensor_elementwise_impl.cpython-312-x86_64-linux-gnu.so
-rw-r--r-- 1 opavlyk opavlyk 37176600 Jan 21 06:36 dpctl/tensor/_tensor_elementwise_impl.cpython-312-x86_64-linux-gnu.so
```

Also added static assertions to offset_utils to ensure that indexers are device copyable.
1 parent d9e9bf8 commit 0399b93
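The offset_utils changes are not shown on this page. For reference, a minimal sketch of what such a device-copyability assertion can look like in SYCL 2020, with `MyIndexer` as a hypothetical stand-in for the indexer types in `utils/offset_utils.hpp`:

```cpp
#include <cstddef>

#include <sycl/sycl.hpp>

// Hypothetical stand-in for the indexer types defined in
// utils/offset_utils.hpp (the actual diff is not shown on this page).
struct MyIndexer
{
    std::size_t offset;
    std::size_t operator()(std::size_t i) const { return offset + i; }
};

// SYCL 2020 trait: compilation fails if the type cannot be copied
// into device code, e.g. if it later gains a virtual function or a
// non-trivially-copyable member.
static_assert(sycl::is_device_copyable_v<MyIndexer>,
              "Indexer must be device copyable");
```

Trivially copyable types satisfy the trait automatically, so the assertion costs nothing at run time and catches regressions at compile time.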

4 files changed: +155 −50 lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp

Lines changed: 36 additions & 33 deletions
```diff
@@ -26,11 +26,13 @@
 #include <cstddef>
 #include <cstdint>
 #include <stdexcept>
-#include <sycl/sycl.hpp>
 #include <utility>
 
+#include <sycl/sycl.hpp>
+
 #include "kernels/alignment.hpp"
 #include "kernels/dpctl_tensor_types.hpp"
+#include "kernels/elementwise_functions/common_detail.hpp"
 #include "utils/offset_utils.hpp"
 #include "utils/sycl_alloc_utils.hpp"
 #include "utils/sycl_utils.hpp"
@@ -324,21 +326,23 @@ sycl::event unary_contig_impl(sycl::queue &exec_q,
         {
             constexpr bool enable_sg_loadstore = true;
             using KernelName = BaseKernelName;
+            using Impl = ContigFunctorT<argTy, resTy, vec_sz, n_vecs,
+                                        enable_sg_loadstore>;
 
             cgh.parallel_for<KernelName>(
                 sycl::nd_range<1>(gws_range, lws_range),
-                ContigFunctorT<argTy, resTy, vec_sz, n_vecs,
-                               enable_sg_loadstore>(arg_tp, res_tp, nelems));
+                Impl(arg_tp, res_tp, nelems));
         }
         else {
             constexpr bool disable_sg_loadstore = false;
             using KernelName =
                 disabled_sg_loadstore_wrapper_krn<BaseKernelName>;
+            using Impl = ContigFunctorT<argTy, resTy, vec_sz, n_vecs,
+                                        disable_sg_loadstore>;
 
             cgh.parallel_for<KernelName>(
                 sycl::nd_range<1>(gws_range, lws_range),
-                ContigFunctorT<argTy, resTy, vec_sz, n_vecs,
-                               disable_sg_loadstore>(arg_tp, res_tp, nelems));
+                Impl(arg_tp, res_tp, nelems));
         }
     });
 
@@ -377,9 +381,10 @@ unary_strided_impl(sycl::queue &exec_q,
         const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
         resTy *res_tp = reinterpret_cast<resTy *>(res_p);
 
+        using Impl = StridedFunctorT<argTy, resTy, IndexerT>;
+
         cgh.parallel_for<kernel_name<argTy, resTy, IndexerT>>(
-            {nelems},
-            StridedFunctorT<argTy, resTy, IndexerT>(arg_tp, res_tp, indexer));
+            {nelems}, Impl(arg_tp, res_tp, indexer));
     });
     return comp_ev;
 }
@@ -814,22 +819,23 @@ sycl::event binary_contig_impl(sycl::queue &exec_q,
         {
             constexpr bool enable_sg_loadstore = true;
             using KernelName = BaseKernelName;
+            using Impl = BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz,
+                                              n_vecs, enable_sg_loadstore>;
 
             cgh.parallel_for<KernelName>(
                 sycl::nd_range<1>(gws_range, lws_range),
-                BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz, n_vecs,
-                                     enable_sg_loadstore>(arg1_tp, arg2_tp,
-                                                          res_tp, nelems));
+                Impl(arg1_tp, arg2_tp, res_tp, nelems));
         }
         else {
             constexpr bool disable_sg_loadstore = false;
             using KernelName =
                 disabled_sg_loadstore_wrapper_krn<BaseKernelName>;
+            using Impl = BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz,
+                                              n_vecs, disable_sg_loadstore>;
+
             cgh.parallel_for<KernelName>(
                 sycl::nd_range<1>(gws_range, lws_range),
-                BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz, n_vecs,
-                                     disable_sg_loadstore>(arg1_tp, arg2_tp,
-                                                           res_tp, nelems));
+                Impl(arg1_tp, arg2_tp, res_tp, nelems));
         }
     });
     return comp_ev;
@@ -873,9 +879,10 @@ binary_strided_impl(sycl::queue &exec_q,
         const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
         resTy *res_tp = reinterpret_cast<resTy *>(res_p);
 
+        using Impl = BinaryStridedFunctorT<argTy1, argTy2, resTy, IndexerT>;
+
         cgh.parallel_for<kernel_name<argTy1, argTy2, resTy, IndexerT>>(
-            {nelems}, BinaryStridedFunctorT<argTy1, argTy2, resTy, IndexerT>(
-                          arg1_tp, arg2_tp, res_tp, indexer));
+            {nelems}, Impl(arg1_tp, arg2_tp, res_tp, indexer));
     });
     return comp_ev;
 }
@@ -917,13 +924,9 @@ sycl::event binary_contig_matrix_contig_row_broadcast_impl(
             exec_q);
     argT2 *padded_vec = padded_vec_owner.get();
 
-    sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(depends); // ensure vec contains actual data
-        cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) {
-            auto i = id[0];
-            padded_vec[i] = vec[i % n1];
-        });
-    });
+    sycl::event make_padded_vec_ev =
+        dpctl::tensor::kernels::elementwise_detail::populate_padded_vector<
+            argT2>(exec_q, vec, n1, padded_vec, n1_padded, depends);
 
     // sub-group spans work-items [I, I + sgSize)
     // base = ndit.get_global_linear_id() - sg.get_local_id()[0]
@@ -942,10 +945,12 @@ sycl::event binary_contig_matrix_contig_row_broadcast_impl(
         std::size_t n_groups = (n_elems + lws - 1) / lws;
         auto gwsRange = sycl::range<1>(n_groups * lws);
 
+        using Impl =
+            BinaryContigMatrixContigRowBroadcastFunctorT<argT1, argT2, resT>;
+
         cgh.parallel_for<class kernel_name<argT1, argT2, resT>>(
             sycl::nd_range<1>(gwsRange, lwsRange),
-            BinaryContigMatrixContigRowBroadcastFunctorT<argT1, argT2, resT>(
-                mat, padded_vec, res, n_elems, n1));
+            Impl(mat, padded_vec, res, n_elems, n1));
     });
 
     sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
@@ -993,13 +998,9 @@ sycl::event binary_contig_row_contig_matrix_broadcast_impl(
             exec_q);
     argT2 *padded_vec = padded_vec_owner.get();
 
-    sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(depends); // ensure vec contains actual data
-        cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) {
-            auto i = id[0];
-            padded_vec[i] = vec[i % n1];
-        });
-    });
+    sycl::event make_padded_vec_ev =
+        dpctl::tensor::kernels::elementwise_detail::populate_padded_vector<
+            argT2>(exec_q, vec, n1, padded_vec, n1_padded, depends);
 
     // sub-group spans work-items [I, I + sgSize)
     // base = ndit.get_global_linear_id() - sg.get_local_id()[0]
@@ -1018,10 +1019,12 @@ sycl::event binary_contig_row_contig_matrix_broadcast_impl(
         std::size_t n_groups = (n_elems + lws - 1) / lws;
         auto gwsRange = sycl::range<1>(n_groups * lws);
 
+        using Impl =
+            BinaryContigRowContigMatrixBroadcastFunctorT<argT1, argT2, resT>;
+
         cgh.parallel_for<class kernel_name<argT1, argT2, resT>>(
             sycl::nd_range<1>(gwsRange, lwsRange),
-            BinaryContigRowContigMatrixBroadcastFunctorT<argT1, argT2, resT>(
-                padded_vec, mat, res, n_elems, n1));
+            Impl(padded_vec, mat, res, n_elems, n1));
    });
 
     sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
```
dpctl/tensor/libtensor/include/kernels/elementwise_functions/common_detail.hpp (new file)

Lines changed: 70 additions & 0 deletions
```diff
@@ -0,0 +1,70 @@
+//=== common_detail.hpp - - *-C++-*--/===//
+//
+// Data Parallel Control (dpctl)
+//
+// Copyright 2020-2025 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===---------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines common code for elementwise tensor operations.
+//===---------------------------------------------------------------------===//
+
+#pragma once
+#include <cstddef>
+#include <vector>
+
+#include <sycl/sycl.hpp>
+
+namespace dpctl
+{
+namespace tensor
+{
+namespace kernels
+{
+namespace elementwise_detail
+{
+
+template <typename T> class populate_padded_vec_krn;
+
+template <typename T>
+sycl::event
+populate_padded_vector(sycl::queue &exec_q,
+                       const T *vec,
+                       std::size_t vec_sz,
+                       T *padded_vec,
+                       size_t padded_vec_sz,
+                       const std::vector<sycl::event> &dependent_events)
+{
+    sycl::event populate_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) {
+        // ensure vec contains actual data
+        cgh.depends_on(dependent_events);
+
+        sycl::range<1> gRange{padded_vec_sz};
+
+        cgh.parallel_for<class populate_padded_vec_krn<T>>(
+            gRange, [=](sycl::id<1> id) {
+                std::size_t i = id[0];
+                padded_vec[i] = vec[i % vec_sz];
+            });
+    });
+
+    return populate_padded_vec_ev;
+}
+
+} // end of namespace elementwise_detail
+} // end of namespace kernels
+} // end of namespace tensor
+} // end of namespace dpctl
```
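For context, a minimal sketch of calling the new helper directly; the queue setup, USM allocations, and fill value below are illustrative and not part of the commit:

```cpp
#include <cstddef>

#include <sycl/sycl.hpp>

#include "kernels/elementwise_functions/common_detail.hpp"

int main()
{
    sycl::queue q;

    constexpr std::size_t n1 = 3;        // length of the source vector
    constexpr std::size_t n1_padded = 8; // padded length, e.g. rounded up
                                         // to a multiple of the sub-group size

    float *vec = sycl::malloc_device<float>(n1, q);
    float *padded_vec = sycl::malloc_device<float>(n1_padded, q);

    // Stand-in for whatever event produced vec's contents.
    sycl::event fill_ev = q.fill(vec, 1.0f, n1);

    // Cyclically repeats vec into padded_vec: padded_vec[i] = vec[i % n1].
    sycl::event ev =
        dpctl::tensor::kernels::elementwise_detail::populate_padded_vector<
            float>(q, vec, n1, padded_vec, n1_padded, {fill_ev});
    ev.wait();

    sycl::free(padded_vec, q);
    sycl::free(vec, q);
    return 0;
}
```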

dpctl/tensor/libtensor/include/kernels/elementwise_functions/common_inplace.hpp

Lines changed: 20 additions & 17 deletions
```diff
@@ -27,10 +27,12 @@
 #include <cstddef>
 #include <cstdint>
 #include <stdexcept>
+
 #include <sycl/sycl.hpp>
 
 #include "kernels/alignment.hpp"
 #include "kernels/dpctl_tensor_types.hpp"
+#include "kernels/elementwise_functions/common_detail.hpp"
 #include "utils/offset_utils.hpp"
 #include "utils/sycl_alloc_utils.hpp"
 #include "utils/sycl_utils.hpp"
@@ -337,23 +339,26 @@ binary_inplace_contig_impl(sycl::queue &exec_q,
         {
             constexpr bool enable_sg_loadstore = true;
             using KernelName = kernel_name<argTy, resTy, vec_sz, n_vecs>;
+            using Impl =
+                BinaryInplaceContigFunctorT<argTy, resTy, vec_sz, n_vecs,
+                                            enable_sg_loadstore>;
+
             cgh.parallel_for<KernelName>(
                 sycl::nd_range<1>(gws_range, lws_range),
-                BinaryInplaceContigFunctorT<argTy, resTy, vec_sz, n_vecs,
-                                            enable_sg_loadstore>(arg_tp, res_tp,
-                                                                 nelems));
+                Impl(arg_tp, res_tp, nelems));
         }
         else {
             constexpr bool disable_sg_loadstore = true;
             using InnerKernelName = kernel_name<argTy, resTy, vec_sz, n_vecs>;
             using KernelName =
                 disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
+            using Impl =
+                BinaryInplaceContigFunctorT<argTy, resTy, vec_sz, n_vecs,
+                                            disable_sg_loadstore>;
 
             cgh.parallel_for<KernelName>(
                 sycl::nd_range<1>(gws_range, lws_range),
-                BinaryInplaceContigFunctorT<argTy, resTy, vec_sz, n_vecs,
-                                            disable_sg_loadstore>(
-                    arg_tp, res_tp, nelems));
+                Impl(arg_tp, res_tp, nelems));
         }
     });
     return comp_ev;
@@ -389,9 +394,10 @@ binary_inplace_strided_impl(sycl::queue &exec_q,
         const argTy *arg_tp = reinterpret_cast<const argTy *>(rhs_p);
         resTy *res_tp = reinterpret_cast<resTy *>(lhs_p);
 
+        using Impl = BinaryInplaceStridedFunctorT<argTy, resTy, IndexerT>;
+
         cgh.parallel_for<kernel_name<argTy, resTy, IndexerT>>(
-            {nelems}, BinaryInplaceStridedFunctorT<argTy, resTy, IndexerT>(
-                          arg_tp, res_tp, indexer));
+            {nelems}, Impl(arg_tp, res_tp, indexer));
     });
     return comp_ev;
 }
@@ -428,13 +434,9 @@ sycl::event binary_inplace_row_matrix_broadcast_impl(
             exec_q);
     argT *padded_vec = padded_vec_owner.get();
 
-    sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(depends); // ensure vec contains actual data
-        cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) {
-            auto i = id[0];
-            padded_vec[i] = vec[i % n1];
-        });
-    });
+    sycl::event make_padded_vec_ev =
+        dpctl::tensor::kernels::elementwise_detail::populate_padded_vector<
+            argT>(exec_q, vec, n1, padded_vec, n1_padded, depends);
 
     // sub-group spans work-items [I, I + sgSize)
     // base = ndit.get_global_linear_id() - sg.get_local_id()[0]
@@ -453,10 +455,11 @@ sycl::event binary_inplace_row_matrix_broadcast_impl(
         std::size_t n_groups = (n_elems + lws - 1) / lws;
         auto gwsRange = sycl::range<1>(n_groups * lws);
 
+        using Impl = BinaryInplaceRowMatrixBroadcastFunctorT<argT, resT>;
+
         cgh.parallel_for<class kernel_name<argT, resT>>(
             sycl::nd_range<1>(gwsRange, lwsRange),
-            BinaryInplaceRowMatrixBroadcastFunctorT<argT, resT>(padded_vec, mat,
-                                                                n_elems, n1));
+            Impl(padded_vec, mat, n_elems, n1));
     });
 
     sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
```
