Skip to content

Commit 4439bf7

Browse files
Renamed struct to better reflect its purpose
1 parent 5f659fe commit 4439bf7

File tree

7 files changed

+180
-139
lines changed

7 files changed

+180
-139
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,21 +64,21 @@ template <typename T> struct AbsOutputType
6464
{
6565
using value_type = typename std::disjunction< // disjunction is C++17
6666
// feature, supported by DPC++
67-
td_ns::TypeMapEntry<T, bool>,
68-
td_ns::TypeMapEntry<T, std::uint8_t>,
69-
td_ns::TypeMapEntry<T, std::uint16_t>,
70-
td_ns::TypeMapEntry<T, std::uint32_t>,
71-
td_ns::TypeMapEntry<T, std::uint64_t>,
72-
td_ns::TypeMapEntry<T, std::int8_t>,
73-
td_ns::TypeMapEntry<T, std::int16_t>,
74-
td_ns::TypeMapEntry<T, std::int32_t>,
75-
td_ns::TypeMapEntry<T, std::int64_t>,
76-
td_ns::TypeMapEntry<T, sycl::half>,
77-
td_ns::TypeMapEntry<T, float>,
78-
td_ns::TypeMapEntry<T, double>,
79-
td_ns::TypeMapEntry<T, std::complex<float>, float>,
80-
td_ns::TypeMapEntry<T, std::complex<double>, double>,
81-
td_ns::DefaultEntry<void>>::result_type;
67+
td_ns::TypeMapResultEntry<T, bool>,
68+
td_ns::TypeMapResultEntry<T, std::uint8_t>,
69+
td_ns::TypeMapResultEntry<T, std::uint16_t>,
70+
td_ns::TypeMapResultEntry<T, std::uint32_t>,
71+
td_ns::TypeMapResultEntry<T, std::uint64_t>,
72+
td_ns::TypeMapResultEntry<T, std::int8_t>,
73+
td_ns::TypeMapResultEntry<T, std::int16_t>,
74+
td_ns::TypeMapResultEntry<T, std::int32_t>,
75+
td_ns::TypeMapResultEntry<T, std::int64_t>,
76+
td_ns::TypeMapResultEntry<T, sycl::half>,
77+
td_ns::TypeMapResultEntry<T, float>,
78+
td_ns::TypeMapResultEntry<T, double>,
79+
td_ns::TypeMapResultEntry<T, std::complex<float>, float>,
80+
td_ns::TypeMapResultEntry<T, std::complex<double>, double>,
81+
td_ns::DefaultResultEntry<void>>::result_type;
8282
};
8383

8484
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>

dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

Lines changed: 59 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -80,58 +80,65 @@ template <typename T1, typename T2> struct AddOutputType
8080
{
8181
using value_type = typename std::disjunction< // disjunction is C++17
8282
// feature, supported by DPC++
83-
td_ns::BinaryTypeMapEntry<T1, bool, T2, bool, bool>,
84-
td_ns::BinaryTypeMapEntry<T1,
85-
std::uint8_t,
86-
T2,
87-
std::uint8_t,
88-
std::uint8_t>,
89-
td_ns::
90-
BinaryTypeMapEntry<T1, std::int8_t, T2, std::int8_t, std::int8_t>,
91-
td_ns::BinaryTypeMapEntry<T1,
92-
std::uint16_t,
93-
T2,
94-
std::uint16_t,
95-
std::uint16_t>,
96-
td_ns::BinaryTypeMapEntry<T1,
97-
std::int16_t,
98-
T2,
99-
std::int16_t,
100-
std::int16_t>,
101-
td_ns::BinaryTypeMapEntry<T1,
102-
std::uint32_t,
103-
T2,
104-
std::uint32_t,
105-
std::uint32_t>,
106-
td_ns::BinaryTypeMapEntry<T1,
107-
std::int32_t,
108-
T2,
109-
std::int32_t,
110-
std::int32_t>,
111-
td_ns::BinaryTypeMapEntry<T1,
112-
std::uint64_t,
113-
T2,
114-
std::uint64_t,
115-
std::uint64_t>,
116-
td_ns::BinaryTypeMapEntry<T1,
117-
std::int64_t,
118-
T2,
119-
std::int64_t,
120-
std::int64_t>,
121-
td_ns::BinaryTypeMapEntry<T1, sycl::half, T2, sycl::half, sycl::half>,
122-
td_ns::BinaryTypeMapEntry<T1, float, T2, float, float>,
123-
td_ns::BinaryTypeMapEntry<T1, double, T2, double, double>,
124-
td_ns::BinaryTypeMapEntry<T1,
125-
std::complex<float>,
126-
T2,
127-
std::complex<float>,
128-
std::complex<float>>,
129-
td_ns::BinaryTypeMapEntry<T1,
130-
std::complex<double>,
131-
T2,
132-
std::complex<double>,
133-
std::complex<double>>,
134-
td_ns::DefaultEntry<void>>::result_type;
83+
td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, bool>,
84+
td_ns::BinaryTypeMapResultEntry<T1,
85+
std::uint8_t,
86+
T2,
87+
std::uint8_t,
88+
std::uint8_t>,
89+
td_ns::BinaryTypeMapResultEntry<T1,
90+
std::int8_t,
91+
T2,
92+
std::int8_t,
93+
std::int8_t>,
94+
td_ns::BinaryTypeMapResultEntry<T1,
95+
std::uint16_t,
96+
T2,
97+
std::uint16_t,
98+
std::uint16_t>,
99+
td_ns::BinaryTypeMapResultEntry<T1,
100+
std::int16_t,
101+
T2,
102+
std::int16_t,
103+
std::int16_t>,
104+
td_ns::BinaryTypeMapResultEntry<T1,
105+
std::uint32_t,
106+
T2,
107+
std::uint32_t,
108+
std::uint32_t>,
109+
td_ns::BinaryTypeMapResultEntry<T1,
110+
std::int32_t,
111+
T2,
112+
std::int32_t,
113+
std::int32_t>,
114+
td_ns::BinaryTypeMapResultEntry<T1,
115+
std::uint64_t,
116+
T2,
117+
std::uint64_t,
118+
std::uint64_t>,
119+
td_ns::BinaryTypeMapResultEntry<T1,
120+
std::int64_t,
121+
T2,
122+
std::int64_t,
123+
std::int64_t>,
124+
td_ns::BinaryTypeMapResultEntry<T1,
125+
sycl::half,
126+
T2,
127+
sycl::half,
128+
sycl::half>,
129+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, float>,
130+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, double>,
131+
td_ns::BinaryTypeMapResultEntry<T1,
132+
std::complex<float>,
133+
T2,
134+
std::complex<float>,
135+
std::complex<float>>,
136+
td_ns::BinaryTypeMapResultEntry<T1,
137+
std::complex<double>,
138+
T2,
139+
std::complex<double>,
140+
std::complex<double>>,
141+
td_ns::DefaultResultEntry<void>>::result_type;
135142
};
136143

137144
template <typename argT1,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,13 @@ template <typename T> struct CosOutputType
6060
{
6161
using value_type = typename std::disjunction< // disjunction is C++17
6262
// feature, supported by DPC++
63-
td_ns::TypeMapEntry<T, sycl::half, sycl::half>,
64-
td_ns::TypeMapEntry<T, float, float>,
65-
td_ns::TypeMapEntry<T, double, double>,
66-
td_ns::TypeMapEntry<T, std::complex<float>, std::complex<float>>,
67-
td_ns::TypeMapEntry<T, std::complex<double>, std::complex<double>>,
68-
td_ns::DefaultEntry<void>>::result_type;
63+
td_ns::TypeMapResultEntry<T, sycl::half, sycl::half>,
64+
td_ns::TypeMapResultEntry<T, float, float>,
65+
td_ns::TypeMapResultEntry<T, double, double>,
66+
td_ns::TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
67+
td_ns::
68+
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
69+
td_ns::DefaultResultEntry<void>>::result_type;
6970
};
7071

7172
typedef sycl::event (*cos_contig_impl_fn_ptr_t)(

dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -83,29 +83,45 @@ template <typename T1, typename T2> struct EqualOutputType
8383
{
8484
using value_type = typename std::disjunction< // disjunction is C++17
8585
// feature, supported by DPC++
86-
td_ns::BinaryTypeMapEntry<T1, bool, T2, bool, bool>,
87-
td_ns::BinaryTypeMapEntry<T1, std::uint8_t, T2, std::uint8_t, bool>,
88-
td_ns::BinaryTypeMapEntry<T1, std::int8_t, T2, std::int8_t, bool>,
89-
td_ns::BinaryTypeMapEntry<T1, std::uint16_t, T2, std::uint16_t, bool>,
90-
td_ns::BinaryTypeMapEntry<T1, std::int16_t, T2, std::int16_t, bool>,
91-
td_ns::BinaryTypeMapEntry<T1, std::uint32_t, T2, std::uint32_t, bool>,
92-
td_ns::BinaryTypeMapEntry<T1, std::int32_t, T2, std::int32_t, bool>,
93-
td_ns::BinaryTypeMapEntry<T1, std::uint64_t, T2, std::uint64_t, bool>,
94-
td_ns::BinaryTypeMapEntry<T1, std::int64_t, T2, std::int64_t, bool>,
95-
td_ns::BinaryTypeMapEntry<T1, sycl::half, T2, sycl::half, bool>,
96-
td_ns::BinaryTypeMapEntry<T1, float, T2, float, bool>,
97-
td_ns::BinaryTypeMapEntry<T1, double, T2, double, bool>,
98-
td_ns::BinaryTypeMapEntry<T1,
99-
std::complex<float>,
100-
T2,
101-
std::complex<float>,
102-
bool>,
103-
td_ns::BinaryTypeMapEntry<T1,
104-
std::complex<double>,
105-
T2,
106-
std::complex<double>,
107-
bool>,
108-
td_ns::DefaultEntry<void>>::result_type;
86+
td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, bool>,
87+
td_ns::
88+
BinaryTypeMapResultEntry<T1, std::uint8_t, T2, std::uint8_t, bool>,
89+
td_ns::BinaryTypeMapResultEntry<T1, std::int8_t, T2, std::int8_t, bool>,
90+
td_ns::BinaryTypeMapResultEntry<T1,
91+
std::uint16_t,
92+
T2,
93+
std::uint16_t,
94+
bool>,
95+
td_ns::
96+
BinaryTypeMapResultEntry<T1, std::int16_t, T2, std::int16_t, bool>,
97+
td_ns::BinaryTypeMapResultEntry<T1,
98+
std::uint32_t,
99+
T2,
100+
std::uint32_t,
101+
bool>,
102+
td_ns::
103+
BinaryTypeMapResultEntry<T1, std::int32_t, T2, std::int32_t, bool>,
104+
td_ns::BinaryTypeMapResultEntry<T1,
105+
std::uint64_t,
106+
T2,
107+
std::uint64_t,
108+
bool>,
109+
td_ns::
110+
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::int64_t, bool>,
111+
td_ns::BinaryTypeMapResultEntry<T1, sycl::half, T2, sycl::half, bool>,
112+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, bool>,
113+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, bool>,
114+
td_ns::BinaryTypeMapResultEntry<T1,
115+
std::complex<float>,
116+
T2,
117+
std::complex<float>,
118+
bool>,
119+
td_ns::BinaryTypeMapResultEntry<T1,
120+
std::complex<double>,
121+
T2,
122+
std::complex<double>,
123+
bool>,
124+
td_ns::DefaultResultEntry<void>>::result_type;
109125
};
110126

111127
template <typename argT1,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,13 @@ template <typename T> struct SqrtOutputType
6060
{
6161
using value_type = typename std::disjunction< // disjunction is C++17
6262
// feature, supported by DPC++
63-
td_ns::TypeMapEntry<T, sycl::half, sycl::half>,
64-
td_ns::TypeMapEntry<T, float, float>,
65-
td_ns::TypeMapEntry<T, double, double>,
66-
td_ns::TypeMapEntry<T, std::complex<float>, std::complex<float>>,
67-
td_ns::TypeMapEntry<T, std::complex<double>, std::complex<double>>,
68-
td_ns::DefaultEntry<void>>::result_type;
63+
td_ns::TypeMapResultEntry<T, sycl::half, sycl::half>,
64+
td_ns::TypeMapResultEntry<T, float, float>,
65+
td_ns::TypeMapResultEntry<T, double, double>,
66+
td_ns::TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
67+
td_ns::
68+
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
69+
td_ns::DefaultResultEntry<void>>::result_type;
6970
};
7071

7172
typedef sycl::event (*sqrt_contig_impl_fn_ptr_t)(

dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -81,40 +81,44 @@ template <typename T1, typename T2> struct TrueDivideOutputType
8181
{
8282
using value_type = typename std::disjunction< // disjunction is C++17
8383
// feature, supported by DPC++
84-
td_ns::BinaryTypeMapEntry<T1, sycl::half, T2, sycl::half, sycl::half>,
85-
td_ns::BinaryTypeMapEntry<T1, float, T2, float, float>,
86-
td_ns::BinaryTypeMapEntry<T1, double, T2, double, double>,
87-
td_ns::BinaryTypeMapEntry<T1,
88-
std::complex<float>,
89-
T2,
90-
std::complex<float>,
91-
std::complex<float>>,
92-
td_ns::BinaryTypeMapEntry<T1,
93-
std::complex<float>,
94-
T2,
95-
float,
96-
std::complex<float>>,
97-
td_ns::BinaryTypeMapEntry<T1,
98-
float,
99-
T2,
100-
std::complex<float>,
101-
std::complex<float>>,
102-
td_ns::BinaryTypeMapEntry<T1,
103-
std::complex<double>,
104-
T2,
105-
std::complex<double>,
106-
std::complex<double>>,
107-
td_ns::BinaryTypeMapEntry<T1,
108-
double,
109-
T2,
110-
std::complex<double>,
111-
std::complex<double>>,
112-
td_ns::BinaryTypeMapEntry<T1,
113-
std::complex<double>,
114-
T2,
115-
double,
116-
std::complex<double>>,
117-
td_ns::DefaultEntry<void>>::result_type;
84+
td_ns::BinaryTypeMapResultEntry<T1,
85+
sycl::half,
86+
T2,
87+
sycl::half,
88+
sycl::half>,
89+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, float>,
90+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, double>,
91+
td_ns::BinaryTypeMapResultEntry<T1,
92+
std::complex<float>,
93+
T2,
94+
std::complex<float>,
95+
std::complex<float>>,
96+
td_ns::BinaryTypeMapResultEntry<T1,
97+
std::complex<float>,
98+
T2,
99+
float,
100+
std::complex<float>>,
101+
td_ns::BinaryTypeMapResultEntry<T1,
102+
float,
103+
T2,
104+
std::complex<float>,
105+
std::complex<float>>,
106+
td_ns::BinaryTypeMapResultEntry<T1,
107+
std::complex<double>,
108+
T2,
109+
std::complex<double>,
110+
std::complex<double>>,
111+
td_ns::BinaryTypeMapResultEntry<T1,
112+
double,
113+
T2,
114+
std::complex<double>,
115+
std::complex<double>>,
116+
td_ns::BinaryTypeMapResultEntry<T1,
117+
std::complex<double>,
118+
T2,
119+
double,
120+
std::complex<double>>,
121+
td_ns::DefaultResultEntry<void>>::result_type;
118122
};
119123

120124
template <typename argT1,

dpctl/tensor/libtensor/include/utils/type_dispatch.hpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ struct usm_ndarray_types
252252

253253
/*! @brief struct to define result_type typename for Ty == ArgTy */
254254
template <typename Ty, typename ArgTy, typename ResTy = ArgTy>
255-
struct TypeMapEntry : std::bool_constant<std::is_same_v<Ty, ArgTy>>
255+
struct TypeMapResultEntry : std::bool_constant<std::is_same_v<Ty, ArgTy>>
256256
{
257257
using result_type = ResTy;
258258
};
@@ -264,15 +264,15 @@ template <typename Ty1,
264264
typename Ty2,
265265
typename ArgTy2,
266266
typename ResTy>
267-
struct BinaryTypeMapEntry
267+
struct BinaryTypeMapResultEntry
268268
: std::bool_constant<std::conjunction_v<std::is_same<Ty1, ArgTy1>,
269269
std::is_same<Ty2, ArgTy2>>>
270270
{
271271
using result_type = ResTy;
272272
};
273273

274274
/*! @brief fall-through struct with specified result_type, usually void */
275-
template <typename Ty = void> struct DefaultEntry : std::true_type
275+
template <typename Ty = void> struct DefaultResultEntry : std::true_type
276276
{
277277
using result_type = Ty;
278278
};
@@ -368,6 +368,18 @@ template <typename FunPtrT> struct NullPtrTable
368368
value_type val;
369369
};
370370

371+
template <typename Ty1, typename ArgTy, typename Ty2, typename outTy>
372+
struct TypePairDefinedEntry : std::bool_constant<std::is_same_v<Ty1, ArgTy> &&
373+
std::is_same_v<Ty2, outTy>>
374+
{
375+
static constexpr bool is_defined = true;
376+
};
377+
378+
struct NotDefinedEntry : std::true_type
379+
{
380+
static constexpr bool is_defined = false;
381+
};
382+
371383
} // namespace type_dispatch
372384

373385
} // namespace tensor

0 commit comments

Comments
 (0)