Skip to content

Commit 6636103

Browse files
authored
[ESIMD] Fix saturation argument of DPAS (#6647)
* [ESIMD] Fix saturation argument of DPAS The template argument for saturation was declared/used such a way that any type could be passed to it, which would cause enforcement of saturation when not intended. In even worse scenarios the DPAS call with 3 simd arguments was recognized as DPAS with 2 simd arguments + saturation argument: dpas(src0,src1,src2) was treated as dpas(src1,src2,sat), which caused totally incorrect behavior at runtime. Also, this patch fixes the incorrect detection of ops_per_channel for tfloat32 type on HOST. Signed-off-by: Vyacheslav N Klochkov <vyacheslav.n.klochkov@intel.com>
1 parent 1660a61 commit 6636103

File tree

3 files changed

+26
-22
lines changed

3 files changed

+26
-22
lines changed

sycl/include/sycl/ext/intel/esimd/common.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,16 @@ using SurfaceIndex = unsigned int;
6565

6666
namespace detail {
6767

68+
template <typename T>
69+
struct is_saturation_tag {
70+
static constexpr bool value =
71+
std::is_same_v<T, __ESIMD_NS::saturation_on_tag> ||
72+
std::is_same_v<T, __ESIMD_NS::saturation_off_tag>;
73+
};
74+
75+
template <class T>
76+
inline constexpr bool is_saturation_tag_v = is_saturation_tag<T>::value;
77+
6878
/// Check if a given 32 bit positive integer is a power of 2 at compile time.
6979
ESIMD_INLINE constexpr bool isPowerOf2(unsigned int n) {
7080
return (n & (n - 1)) == 0;

sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -451,24 +451,15 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t<T0, SZ> *src0,
451451
__ESIMD_EMU_DNS::SetSatur<T2,
452452
__ESIMD_EMU_DNS::is_inttype<RT>::value>::set();
453453

454-
constexpr __ESIMD_NS::uint ops_per_chan =
455-
src1_precision == __ESIMD_ENS::argument_type::BF16 ||
456-
src1_precision == __ESIMD_ENS::argument_type::FP16 ||
457-
src2_precision == __ESIMD_ENS::argument_type::BF16 ||
458-
src2_precision == __ESIMD_ENS::argument_type::FP16
459-
? 2
460-
: src1_precision == __ESIMD_ENS::argument_type::S8 ||
461-
src1_precision == __ESIMD_ENS::argument_type::U8 ||
462-
src2_precision == __ESIMD_ENS::argument_type::S8 ||
463-
src2_precision == __ESIMD_ENS::argument_type::U8
464-
? 4
465-
: 8;
466-
467454
__ESIMD_NS::uint V = 0, U = 0, k = 0, temp = 0, src1_ops_per_dword = 0, p = 0;
468455

469456
constexpr auto src1_el_bits = __esimd_dpas_bits_precision(src1_precision);
470457
constexpr auto src2_el_bits = __esimd_dpas_bits_precision(src2_precision);
471458

459+
constexpr auto max_el_bits = std::max(src1_el_bits, src2_el_bits);
460+
constexpr __ESIMD_NS::uint ops_per_chan =
461+
std::min(32 / max_el_bits, static_cast<__ESIMD_NS::uint>(8));
462+
472463
uint32_t src1_signed =
473464
src1_precision == __ESIMD_ENS::argument_type::S2 ||
474465
src1_precision == __ESIMD_ENS::argument_type::S4 ||

sycl/include/sycl/ext/intel/experimental/esimd/math.hpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1761,7 +1761,8 @@ template <argument_type src1_precision, argument_type src2_precision,
17611761
typename Sat = __ESIMD_NS::saturation_off_tag>
17621762
__ESIMD_API __ESIMD_NS::simd<T, N>
17631763
dpas(__ESIMD_NS::simd<T0, N> src0, __ESIMD_NS::simd<T1, N1> src1,
1764-
__ESIMD_NS::simd<T2, N2> src2, Sat sat = {}) {
1764+
__ESIMD_NS::simd<T2, N2> src2,
1765+
std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v<Sat>, Sat> sat = {}) {
17651766
// types: dst, src0, src1, src2
17661767
// ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2
17671768
constexpr bool check_integer =
@@ -1894,7 +1895,8 @@ template <argument_type src1_precision, argument_type src2_precision,
18941895
typename Sat = __ESIMD_NS::saturation_off_tag>
18951896
__ESIMD_API __ESIMD_NS::simd<T, N>
18961897
dpas(__ESIMD_NS::simd<T, N> src0, __ESIMD_NS::simd<T1, N1> src1,
1897-
__ESIMD_NS::simd<T2, N2> src2, Sat sat = {}) {
1898+
__ESIMD_NS::simd<T2, N2> src2,
1899+
std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v<Sat>, Sat> sat = {}) {
18981900
return dpas<src1_precision, src2_precision, T, systolic_depth, repeat_count>(
18991901
src0, src1, src2, sat);
19001902
}
@@ -1911,9 +1913,9 @@ template <argument_type src1_precision, argument_type src2_precision,
19111913
int systolic_depth, int repeat_count, typename T, typename T1,
19121914
typename T2, int N, int N1, int N2,
19131915
typename Sat = __ESIMD_NS::saturation_off_tag>
1914-
__ESIMD_API __ESIMD_NS::simd<T, N> dpas(__ESIMD_NS::simd<T1, N1> src1,
1915-
__ESIMD_NS::simd<T2, N2> src2,
1916-
Sat sat = {}) {
1916+
__ESIMD_API __ESIMD_NS::simd<T, N>
1917+
dpas(__ESIMD_NS::simd<T1, N1> src1, __ESIMD_NS::simd<T2, N2> src2,
1918+
std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v<Sat>, Sat> sat = {}) {
19171919

19181920
static_assert(__ESIMD_DNS::is_fp_or_dword_type<T>::value,
19191921
"Dst must be FP or DWORD type");
@@ -1976,7 +1978,8 @@ template <argument_type src1_precision, argument_type src2_precision,
19761978
typename Sat = __ESIMD_NS::saturation_off_tag>
19771979
__ESIMD_API __ESIMD_NS::simd<T, N>
19781980
dpasw(__ESIMD_NS::simd<T, N> src0, __ESIMD_NS::simd<T1, N1> src1,
1979-
__ESIMD_NS::simd<T2, N2> src2, Sat sat = {}) {
1981+
__ESIMD_NS::simd<T2, N2> src2,
1982+
std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v<Sat>, Sat> sat = {}) {
19801983
constexpr bool is_4xhf =
19811984
std::is_same_v<T, __ESIMD_DNS::__raw_t<sycl::half>> &&
19821985
(src1_precision == src2_precision) &&
@@ -2048,9 +2051,9 @@ template <argument_type src1_precision, argument_type src2_precision,
20482051
int systolic_depth, int repeat_count, typename T, typename T1,
20492052
typename T2, int N, int N1, int N2,
20502053
typename Sat = __ESIMD_NS::saturation_off_tag>
2051-
__ESIMD_API __ESIMD_NS::simd<T, N> dpasw2(__ESIMD_NS::simd<T1, N1> src1,
2052-
__ESIMD_NS::simd<T2, N2> src2,
2053-
Sat sat = {}) {
2054+
__ESIMD_API __ESIMD_NS::simd<T, N>
2055+
dpasw2(__ESIMD_NS::simd<T1, N1> src1, __ESIMD_NS::simd<T2, N2> src2,
2056+
std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v<Sat>, Sat> sat = {}) {
20542057
constexpr bool is_4xhf =
20552058
std::is_same_v<T, __ESIMD_DNS::__raw_t<sycl::half>> &&
20562059
src1_precision == src2_precision && src1_precision == argument_type::FP16;

0 commit comments

Comments
 (0)