diff --git a/.github/workflows/c-cpp.yml b/.github/workflows/c-cpp.yml index 5f6d480..73f0a0a 100644 --- a/.github/workflows/c-cpp.yml +++ b/.github/workflows/c-cpp.yml @@ -208,6 +208,9 @@ jobs: - name: Run test suite on SPR run: sde -spr -- ./builddir/testexe + - name: Run ICL fp16 tests + # Note: This filters for the _Float16 tests based on the number assigned to it, which could change in the future + run: sde -icx -- ./builddir/testexe --gtest_filter="*/simdsort/2*" SKX-SKL-openmp: diff --git a/lib/x86simdsort-icl.cpp b/lib/x86simdsort-icl.cpp index eeb7b2b..6bbad2c 100644 --- a/lib/x86simdsort-icl.cpp +++ b/lib/x86simdsort-icl.cpp @@ -51,4 +51,31 @@ namespace avx512 { x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending); } } // namespace avx512 +namespace fp16_icl { +#ifdef __FLT16_MAX__ + template <> + void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending) + { + x86simdsortStatic::qsort(arr, size, hasnan, descending); + } + template <> + void qselect(_Float16 *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool descending) + { + x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending); + } + template <> + void partial_qsort(_Float16 *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool descending) + { + x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending); + } +#endif +} // namespace fp16_icl } // namespace xss diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index 6cf261a..a9ded64 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -4,165 +4,60 @@ #include #include +#define DECLAREALLFUNCS(name) \ + namespace name { \ + template \ + XSS_HIDE_SYMBOL void qsort(T *arr, \ + size_t arrsize, \ + bool hasnan = false, \ + bool descending = false); \ + template \ + XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key, \ + T2 *val, \ + size_t arrsize, \ + bool hasnan = false, \ + bool descending = false); \ + template \ + XSS_HIDE_SYMBOL void qselect(T *arr, \ + size_t 
k, \ + size_t arrsize, \ + bool hasnan = false, \ + bool descending = false); \ + template \ + XSS_HIDE_SYMBOL void keyvalue_select(T1 *key, \ + T2 *val, \ + size_t k, \ + size_t arrsize, \ + bool hasnan = false, \ + bool descending = false); \ + template \ + XSS_HIDE_SYMBOL void partial_qsort(T *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan = false, \ + bool descending = false); \ + template \ + XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key, \ + T2 *val, \ + size_t k, \ + size_t arrsize, \ + bool hasnan = false, \ + bool descending = false); \ + template \ + XSS_HIDE_SYMBOL std::vector argsort(T *arr, \ + size_t arrsize, \ + bool hasnan = false, \ + bool descending = false); \ + template \ + XSS_HIDE_SYMBOL std::vector \ + argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); \ + } + namespace xss { -namespace avx512 { - // quicksort - template - XSS_HIDE_SYMBOL void - qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); - // key-value quicksort - template - XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key, - T2 *val, - size_t arrsize, - bool hasnan = false, - bool descending = false); - // quickselect - template - XSS_HIDE_SYMBOL void qselect(T *arr, - size_t k, - size_t arrsize, - bool hasnan = false, - bool descending = false); - // key-value select - template - XSS_HIDE_SYMBOL void keyvalue_select(T1 *key, - T2 *val, - size_t k, - size_t arrsize, - bool hasnan = false, - bool descending = false); - // partial sort - template - XSS_HIDE_SYMBOL void partial_qsort(T *arr, - size_t k, - size_t arrsize, - bool hasnan = false, - bool descending = false); - // key-value partial sort - template - XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key, - T2 *val, - size_t k, - size_t arrsize, - bool hasnan = false, - bool descending = false); - // argsort - template - XSS_HIDE_SYMBOL std::vector argsort(T *arr, - size_t arrsize, - bool hasnan = false, - bool descending = false); - // argselect - template - XSS_HIDE_SYMBOL 
std::vector - argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); -} // namespace avx512 -namespace avx2 { - // quicksort - template - XSS_HIDE_SYMBOL void - qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); - // key-value quicksort - template - XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key, - T2 *val, - size_t arrsize, - bool hasnan = false, - bool descending = false); - // quickselect - template - XSS_HIDE_SYMBOL void qselect(T *arr, - size_t k, - size_t arrsize, - bool hasnan = false, - bool descending = false); - // key-value select - template - XSS_HIDE_SYMBOL void keyvalue_select(T1 *key, - T2 *val, - size_t k, - size_t arrsize, - bool hasnan = false, - bool descending = false); - // partial sort - template - XSS_HIDE_SYMBOL void partial_qsort(T *arr, - size_t k, - size_t arrsize, - bool hasnan = false, - bool descending = false); - // key-value partial sort - template - XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key, - T2 *val, - size_t k, - size_t arrsize, - bool hasnan = false, - bool descending = false); - // argsort - template - XSS_HIDE_SYMBOL std::vector argsort(T *arr, - size_t arrsize, - bool hasnan = false, - bool descending = false); - // argselect - template - XSS_HIDE_SYMBOL std::vector - argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); -} // namespace avx2 -namespace scalar { - // quicksort - template - XSS_HIDE_SYMBOL void - qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); - // key-value quicksort - template - XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key, - T2 *val, - size_t arrsize, - bool hasnan = false, - bool descending = false); - // quickselect - template - XSS_HIDE_SYMBOL void qselect(T *arr, - size_t k, - size_t arrsize, - bool hasnan = false, - bool descending = false); - // key-value select - template - XSS_HIDE_SYMBOL void keyvalue_select(T1 *key, - T2 *val, - size_t k, - size_t arrsize, - bool hasnan = false, - bool descending = false); - // 
partial sort - template - XSS_HIDE_SYMBOL void partial_qsort(T *arr, - size_t k, - size_t arrsize, - bool hasnan = false, - bool descending = false); - // key-value partial sort - template - XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key, - T2 *val, - size_t k, - size_t arrsize, - bool hasnan = false, - bool descending = false); - // argsort - template - XSS_HIDE_SYMBOL std::vector argsort(T *arr, - size_t arrsize, - bool hasnan = false, - bool descending = false); - // argselect - template - XSS_HIDE_SYMBOL std::vector - argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); -} // namespace scalar +DECLAREALLFUNCS(avx512) +DECLAREALLFUNCS(avx2) +DECLAREALLFUNCS(scalar) +DECLAREALLFUNCS(fp16_spr) +DECLAREALLFUNCS(fp16_icl) } // namespace xss #endif diff --git a/lib/x86simdsort-spr.cpp b/lib/x86simdsort-spr.cpp index b8069d2..7587640 100644 --- a/lib/x86simdsort-spr.cpp +++ b/lib/x86simdsort-spr.cpp @@ -3,7 +3,7 @@ #include "x86simdsort-internal.h" namespace xss { -namespace avx512 { +namespace fp16_spr { template <> void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending) { @@ -27,5 +27,5 @@ namespace avx512 { { x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending); } -} // namespace avx512 +} // namespace fp16_spr } // namespace xss diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index a5c8d62..8ef9aad 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -108,6 +108,17 @@ namespace x86simdsort { return (*internal_argselect##TYPE)(arr, k, arrsize, hasnan); \ } +/* simple constexpr function as a way around having #ifdef __FLT16_MAX__ block + * within the DISPATCH macro */ +template +constexpr bool IS_TYPE_FLOAT16() +{ +#ifdef __FLT16_MAX__ + if constexpr (std::is_same_v) { return true; } +#endif + return false; +} + /* runtime dispatch mechanism */ #define DISPATCH(func, TYPE, ISA) \ DECLARE_INTERNAL_##func(TYPE) static __attribute__((constructor)) void \ @@ -118,7 +129,24 @@ namespace x86simdsort { 
std::string_view preferred_cpu = find_preferred_cpu(ISA); \ if constexpr (dispatch_requested("avx512", ISA)) { \ if (preferred_cpu.find("avx512") != std::string_view::npos) { \ - CAT(CAT(internal_, func), TYPE) = &xss::avx512::func; \ + if constexpr (IS_TYPE_FLOAT16()) { \ + if (preferred_cpu.find("avx512_spr") \ + != std::string_view::npos) { \ + CAT(CAT(internal_, func), TYPE) \ + = &xss::fp16_spr::func; \ + return; \ + } \ + if (preferred_cpu.find("avx512_icl") \ + != std::string_view::npos) { \ + CAT(CAT(internal_, func), TYPE) \ + = &xss::fp16_icl::func; \ + return; \ + } \ + } \ + else { \ + CAT(CAT(internal_, func), TYPE) \ + = &xss::avx512::func; \ + } \ return; \ } \ } \ @@ -137,9 +165,9 @@ namespace x86simdsort { } #ifdef __FLT16_MAX__ -DISPATCH(qsort, _Float16, ISA_LIST("avx512_spr")) -DISPATCH(qselect, _Float16, ISA_LIST("avx512_spr")) -DISPATCH(partial_qsort, _Float16, ISA_LIST("avx512_spr")) +DISPATCH(qsort, _Float16, ISA_LIST("avx512_spr", "avx512_icl")) +DISPATCH(qselect, _Float16, ISA_LIST("avx512_spr", "avx512_icl")) +DISPATCH(partial_qsort, _Float16, ISA_LIST("avx512_spr", "avx512_icl")) DISPATCH(argsort, _Float16, ISA_LIST("none")) DISPATCH(argselect, _Float16, ISA_LIST("none")) #endif diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index e05027d..6dbe24d 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -9,10 +9,6 @@ #include "avx512-16bit-common.h" -struct float16 { - uint16_t val; -}; - template <> struct zmm_vector { using type_t = uint16_t; @@ -545,10 +541,45 @@ replace_nan_with_inf>(uint16_t *arr, arrsize_t arrsize) return nan_count; } -template <> -X86_SIMD_SORT_INLINE_ONLY bool is_a_nan(uint16_t elem) +template +[[maybe_unused]] X86_SIMD_SORT_INLINE void +avx512_qsort_fp16_helper(uint16_t *arr, arrsize_t arrsize) { - return ((elem & 0x7c00u) == 0x7c00u) && ((elem & 0x03ffu) != 0); + using T = uint16_t; + using vtype = zmm_vector; + +#ifdef XSS_COMPILE_OPENMP + bool use_parallel = arrsize > 
100000; + + if (use_parallel) { + // This thread limit was determined experimentally; it may be better for it to be the number of physical cores on the system + constexpr int thread_limit = 8; + int thread_count = std::min(thread_limit, omp_get_max_threads()); + arrsize_t task_threshold = std::max((arrsize_t)100000, arrsize / 100); + + // We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_ + // The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems + // Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays +#pragma omp parallel num_threads(thread_count) +#pragma omp single + qsort_(arr, + 0, + arrsize - 1, + 2 * (arrsize_t)log2(arrsize), + task_threshold); + } + else { + qsort_(arr, + 0, + arrsize - 1, + 2 * (arrsize_t)log2(arrsize), + std::numeric_limits::max()); + } +#pragma omp taskwait +#else + qsort_( + arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0); +#endif } [[maybe_unused]] X86_SIMD_SORT_INLINE void @@ -559,20 +590,16 @@ avx512_qsort_fp16(uint16_t *arr, { using vtype = zmm_vector; - // TODO multithreading support here if (arrsize > 1) { arrsize_t nan_count = 0; if (UNLIKELY(hasnan)) { - nan_count = replace_nan_with_inf, uint16_t>( - arr, arrsize); + nan_count = replace_nan_with_inf(arr, arrsize); } if (descending) { - qsort_, uint16_t>( - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0); + avx512_qsort_fp16_helper>(arr, arrsize); } else { - qsort_, uint16_t>( - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0); + avx512_qsort_fp16_helper>(arr, arrsize); } replace_inf_with_nan(arr, arrsize, nan_count, descending); } @@ -592,26 +619,37 @@ avx512_qselect_fp16(uint16_t *arr, { using vtype = zmm_vector; - arrsize_t indx_last_elem = arrsize - 1; + // Exit early if no work would be done + if (arrsize <= 1) return; + + arrsize_t index_first_elem = 0; + arrsize_t index_last_elem 
= arrsize - 1; + if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + if (descending) { + index_first_elem = move_nans_to_start_of_array(arr, arrsize); + } + else { + index_last_elem = move_nans_to_end_of_array(arr, arrsize); + } } - if (indx_last_elem >= k) { + + if (index_first_elem <= k && index_last_elem >= k) { if (descending) { qselect_, uint16_t>( arr, k, - 0, - indx_last_elem, - 2 * (arrsize_t)log2(indx_last_elem)); + index_first_elem, + index_last_elem, + 2 * (arrsize_t)log2(arrsize)); } else { qselect_, uint16_t>( arr, k, - 0, - indx_last_elem, - 2 * (arrsize_t)log2(indx_last_elem)); + index_first_elem, + index_last_elem, + 2 * (arrsize_t)log2(arrsize)); } } @@ -628,7 +666,8 @@ avx512_partial_qsort_fp16(uint16_t *arr, bool hasnan = false, bool descending = false) { + if (k == 0) return; avx512_qselect_fp16(arr, k - 1, arrsize, hasnan, descending); - avx512_qsort_fp16(arr, k - 1, descending); + avx512_qsort_fp16(arr, k - 1, hasnan, descending); } #endif // AVX512_QSORT_16BIT diff --git a/src/x86simdsort-static-incl.h b/src/x86simdsort-static-incl.h index 52dde7b..7a42268 100644 --- a/src/x86simdsort-static-incl.h +++ b/src/x86simdsort-static-incl.h @@ -173,6 +173,30 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, XSS_METHODS(avx512) +#if defined(__FLT16_MAX__) && defined(__AVX512BW__) \ + && defined(__AVX512VBMI2__) && !defined(__AVX512FP16__) +template <> +void x86simdsortStatic::qsort<_Float16>(_Float16 *arr, + size_t size, + bool hasnan, + bool descending) +{ + avx512_qsort_fp16((uint16_t *)arr, size, hasnan, descending); +} +template <> +void x86simdsortStatic::qselect<_Float16>( + _Float16 *arr, size_t k, size_t size, bool hasnan, bool descending) +{ + avx512_qselect_fp16((uint16_t *)arr, k, size, hasnan, descending); +} +template <> +void x86simdsortStatic::partial_qsort<_Float16>( + _Float16 *arr, size_t k, size_t size, bool hasnan, bool descending) +{ + avx512_partial_qsort_fp16((uint16_t *)arr, k, 
size, hasnan, descending); +} +#endif + #elif defined(__AVX512F__) #error "x86simdsort requires AVX512DQ and AVX512VL to be enabled in addition to AVX512F to use AVX512" diff --git a/src/xss-common-includes.h b/src/xss-common-includes.h index 7408571..a7c34c1 100644 --- a/src/xss-common-includes.h +++ b/src/xss-common-includes.h @@ -109,4 +109,8 @@ enum class simd_type : int { AVX2, AVX512 }; template X86_SIMD_SORT_INLINE bool comparison_func(const T &a, const T &b); +struct float16 { + uint16_t val; +}; + #endif // XSS_COMMON_INCLUDES diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index e3bf019..cf4a34a 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -45,6 +45,12 @@ bool is_a_nan(T elem) return std::isnan(elem); } +template <> +X86_SIMD_SORT_INLINE_ONLY bool is_a_nan(uint16_t elem) +{ + return ((elem & 0x7c00u) == 0x7c00u) && ((elem & 0x03ffu) != 0); +} + template X86_SIMD_SORT_INLINE arrsize_t replace_nan_with_inf(T *arr, arrsize_t size) { @@ -110,7 +116,7 @@ X86_SIMD_SORT_INLINE void replace_inf_with_nan(type_t *arr, arr[ii] = xss::fp::quiet_NaN(); } else { - arr[ii] = 0xFFFF; + arr[ii] = 0x7c01; // fp16 NaN bit pattern (exponent all ones, mantissa 0x001); quiet bit 0x0200 is clear, so this is not the canonical quiet NaN 0x7e00 } nan_count -= 1; } @@ -121,7 +127,7 @@ X86_SIMD_SORT_INLINE void replace_inf_with_nan(type_t *arr, arr[ii] = xss::fp::quiet_NaN(); } else { - arr[ii] = 0xFFFF; + arr[ii] = 0x7c01; // fp16 NaN bit pattern (exponent all ones, mantissa 0x001); quiet bit 0x0200 is clear, so this is not the canonical quiet NaN 0x7e00 } nan_count -= 1; } diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index 869ecea..f2ce3a6 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -305,8 +305,8 @@ REGISTER_TYPED_TEST_SUITE_P(simdsort, using QSortTestTypes = testing::Types= 13 +// support for _Float16 is incomplete in gcc-12, clang < 6 +#if __GNUC__ >= 13 || __clang_major__ >= 6 _Float16, #endif float,