|
| 1 | +//==------ invoke_simd.hpp - SYCL invoke_simd extension --*- C++ -*---------==// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +// ===--------------------------------------------------------------------=== // |
| 8 | +// Implementation of the sycl_ext_oneapi_invoke_simd extension. |
| 9 | +// https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_invoke_simd.asciidoc |
| 10 | +// ===--------------------------------------------------------------------=== // |
| 11 | + |
| 12 | +#pragma once |
| 13 | + |
| 14 | +// SYCL extension macro definition as required by the SYCL specification. |
| 15 | +// 1 - Initial extension version. Base features are supported. |
| 16 | +#define SYCL_EXT_ONEAPI_INVOKE_SIMD 1 |
| 17 | + |
| 18 | +#include <sycl/ext/oneapi/experimental/uniform.hpp> |
| 19 | + |
| 20 | +#include <CL/sycl/sub_group.hpp> |
| 21 | +#include <std/experimental/simd.hpp> |
| 22 | +#include <sycl/detail/boost/mp11.hpp> |
| 23 | + |
| 24 | +#include <functional> |
| 25 | + |
| 26 | +// TODOs: |
| 27 | +// * (a) TODO bool translation in spmd2simd. |
| 28 | +// * (b) TODO enforce constness of a functor/lambda's () operator |
| 29 | +// * (c) TODO support lambdas and functors in BE |
| 30 | + |
| 31 | +/// Middle End - to - Back End interface to invoke explicit SIMD functions from |
| 32 | +/// SPMD SYCL context. Must not be used by user code. BEs are expected to |
| 33 | +/// recognize this intrinsic and transform the intrinsic call with a direct call |
| 34 | +/// to the SIMD target, as well as process SPMD arguments in the way described |
| 35 | +/// in the specification for `invoke_simd`. |
| 36 | +/// @tparam SpmdRet the return type. Can be `uniform<T>`. |
| 37 | +/// @tparam SimdCallee the type of the SIMD callee function (the "target"). Must |
| 38 | +/// be a function type (not lambda or functor). |
| 39 | +/// @tparam SpmdArgs The original SPMD arguments passed to the invoke_simd. |
| 40 | +template <bool IsFunc, class SpmdRet, class SimdCallee, class... SpmdArgs, |
| 41 | + class = std::enable_if_t<!IsFunc>> |
| 42 | +SYCL_EXTERNAL __regcall SpmdRet |
| 43 | +__builtin_invoke_simd(SimdCallee target, const void *obj, SpmdArgs... args) |
| 44 | +#ifdef __SYCL_DEVICE_ONLY__ |
| 45 | + ; |
| 46 | +#else |
| 47 | +{ |
| 48 | + // __builtin_invoke_simd is not supported on the host device yet |
| 49 | + throw sycl::feature_not_supported(); |
| 50 | +} |
| 51 | +#endif // __SYCL_DEVICE_ONLY__ |
| 52 | + |
| 53 | +template <bool IsFunc, class SpmdRet, class SimdCallee, class... SpmdArgs, |
| 54 | + class = std::enable_if_t<IsFunc>> |
| 55 | +SYCL_EXTERNAL __regcall SpmdRet __builtin_invoke_simd(SimdCallee target, |
| 56 | + SpmdArgs... args) |
| 57 | +#ifdef __SYCL_DEVICE_ONLY__ |
| 58 | + ; |
| 59 | +#else |
| 60 | +{ |
| 61 | + // __builtin_invoke_simd is not supported on the host device yet |
| 62 | + throw sycl::feature_not_supported(); |
| 63 | +} |
| 64 | +#endif // __SYCL_DEVICE_ONLY__ |
| 65 | + |
| 66 | +namespace sycl { |
| 67 | +namespace ext { |
| 68 | +namespace oneapi { |
| 69 | +namespace experimental { |
| 70 | + |
| 71 | +// --- Basic definitions prescribed by the spec. |
| 72 | +namespace simd_abi { |
| 73 | +// "Fixed-size simd width of N" ABI based on clang vectors - used as the ABI for |
| 74 | +// SIMD objects this implementation of invoke_simd spec is based on. |
| 75 | +template <class T, int N> |
| 76 | +using native_fixed_size = typename std::experimental::__simd_abi< |
| 77 | + std::experimental::_StorageKind::_VecExt, N>; |
| 78 | +} // namespace simd_abi |
| 79 | + |
| 80 | +// The SIMD object type, which is the generic std::experimental::simd type with |
| 81 | +// the native fixed size ABI. |
| 82 | +template <class T, int N> |
| 83 | +using simd = std::experimental::simd<T, simd_abi::native_fixed_size<T, N>>; |
| 84 | + |
| 85 | +// The SIMD mask object type. |
| 86 | +template <class T, int N> |
| 87 | +using simd_mask = |
| 88 | + std::experimental::simd_mask<T, simd_abi::native_fixed_size<T, N>>; |
| 89 | + |
| 90 | +// --- Helpers |
| 91 | +namespace detail { |
| 92 | + |
| 93 | +namespace __MP11_NS = sycl::detail::boost::mp11; |
| 94 | + |
| 95 | +// This structure performs the SPMD-to-SIMD parameter type conversion as defined |
| 96 | +// by the spec. |
| 97 | +template <class T, int N, class = void> struct spmd2simd; |
| 98 | +// * `uniform<T>` converts to `T` |
| 99 | +template <class T, int N> struct spmd2simd<uniform<T>, N> { |
| 100 | + using type = T; |
| 101 | +}; |
| 102 | +// * tuple of types converts to tuple of converted tuple element types. |
| 103 | +template <class... T, int N> struct spmd2simd<std::tuple<T...>, N> { |
| 104 | + using type = std::tuple<typename spmd2simd<T, N>::type...>; |
| 105 | +}; |
| 106 | +// * arithmetic type converts to a simd vector with this element type and the |
| 107 | +// width equal to caller's subgroup size and passed as the `N` template |
| 108 | +// argument. |
| 109 | +template <class T, int N> |
| 110 | +struct spmd2simd<T, N, std::enable_if_t<std::is_arithmetic_v<T>>> { |
| 111 | + using type = simd<T, N>; |
| 112 | +}; |
| 113 | + |
| 114 | +// This structure performs the SIMD-to-SPMD return type conversion as defined |
| 115 | +// by the spec. |
| 116 | +template <class, class = void> struct simd2spmd; |
| 117 | +// * `uniform<T>` stays the same |
| 118 | +template <class T> struct simd2spmd<uniform<T>> { |
| 119 | + using type = uniform<T>; |
| 120 | +}; |
| 121 | +// * `simd<T, N>` converts to T |
| 122 | +template <class T, int N> struct simd2spmd<simd<T, N>> { |
| 123 | + using type = T; |
| 124 | +}; |
| 125 | +// * tuple of types converts to tuple of converted tuple element types. |
| 126 | +template <class... T> struct simd2spmd<std::tuple<T...>> { |
| 127 | + using type = std::tuple<typename simd2spmd<T>::type...>; |
| 128 | +}; |
| 129 | +// * arithmetic type T converts to `uniform<T>` |
| 130 | +template <class T> |
| 131 | +struct simd2spmd<T, std::enable_if_t<std::is_arithmetic_v<T>>> { |
| 132 | + using type = uniform<T>; |
| 133 | +}; |
| 134 | + |
| 135 | +// Check if given type is uniform. |
| 136 | +template <class T> struct is_uniform_type : std::false_type {}; |
| 137 | +template <class T> struct is_uniform_type<uniform<T>> : std::true_type { |
| 138 | + using type = T; |
| 139 | +}; |
| 140 | + |
| 141 | +// Check if given type is simd or simd_mask. |
| 142 | +template <class T> struct is_simd_or_mask_type : std::false_type {}; |
| 143 | +template <class T, int N> |
| 144 | +struct is_simd_or_mask_type<simd<T, N>> : std::true_type {}; |
| 145 | +template <class T, int N> |
| 146 | +struct is_simd_or_mask_type<simd_mask<T, N>> : std::true_type {}; |
| 147 | + |
| 148 | +// Checks if the return value type and the types of arguments of given |
| 149 | +// SimdCallable are all uniform. |
| 150 | +template <class SimdCallable, class... SpmdArgs> struct has_uniform_signature { |
| 151 | + constexpr operator bool() { |
| 152 | + using ArgTypeList = __MP11_NS::mp_list<SpmdArgs...>; |
| 153 | + |
| 154 | + if constexpr (__MP11_NS::mp_all_of<ArgTypeList, is_uniform_type>::value) { |
| 155 | + using SimdRet = std::invoke_result_t<SimdCallable, SpmdArgs...>; |
| 156 | + return is_uniform_type<SimdRet>::value || |
| 157 | + !is_simd_or_mask_type<SimdRet>::value; |
| 158 | + } else { |
| 159 | + return false; |
| 160 | + } |
| 161 | + } |
| 162 | +}; |
| 163 | + |
| 164 | +// "Unwraps" a value of the `uniform` type (used before passing to SPMD |
| 165 | +// arguments to the __builtin_invoke_simd): |
| 166 | +// - the case when there is nothing to unwrap |
| 167 | +template <typename T> struct unwrap_uniform { |
| 168 | + static auto impl(T val) { return val; } |
| 169 | +}; |
| 170 | + |
| 171 | +// - the real unwrapping case |
| 172 | +template <typename T> struct unwrap_uniform<uniform<T>> { |
| 173 | + static T impl(uniform<T> val) { return val; } |
| 174 | +}; |
| 175 | + |
| 176 | +// Deduces subgroup size of the caller based on given SIMD callable and |
| 177 | +// corresponding SPMD arguments it is being invoke with via invoke_simd. |
| 178 | +// Basically, for each supported subgroup size, this meta-function finds out if |
| 179 | +// the callable can be invoked by C++ rules given the SPMD arguments transformed |
| 180 | +// as prescribed by the spec assuming this subgroup size. One and only one |
| 181 | +// subgroup size should conform. |
| 182 | +template <class SimdCallable, class... SpmdArgs> struct sg_size { |
| 183 | + template <class N> |
| 184 | + using IsInvocableSgSize = __MP11_NS::mp_bool<std::is_invocable_v< |
| 185 | + SimdCallable, typename spmd2simd<SpmdArgs, N::value>::type...>>; |
| 186 | + |
| 187 | + constexpr operator int() { |
| 188 | + using SupportedSgSizes = __MP11_NS::mp_list_c<int, 1, 2, 4, 8, 16, 32>; |
| 189 | + using InvocableSgSizes = |
| 190 | + __MP11_NS::mp_copy_if<SupportedSgSizes, IsInvocableSgSize>; |
| 191 | + static_assert((__MP11_NS::mp_size<InvocableSgSizes>::value == 1) && |
| 192 | + "no or multiple invoke_simd targets found"); |
| 193 | + return __MP11_NS::mp_front<InvocableSgSizes>::value; |
| 194 | + } |
| 195 | +}; |
| 196 | + |
| 197 | +// Determine the return type of a SIMD callable. |
| 198 | +template <int N, class SimdCallable, class... SpmdArgs> |
| 199 | +using SimdRetType = |
| 200 | + std::invoke_result_t<SimdCallable, |
| 201 | + typename spmd2simd<SpmdArgs, N>::type...>; |
| 202 | +// Determine the return type of an invoke_simd based on the return type of a |
| 203 | +// SIMD callable. |
| 204 | +template <int N, class SimdCallable, class... SpmdArgs> |
| 205 | +using SpmdRetType = |
| 206 | + typename simd2spmd<SimdRetType<N, SimdCallable, SpmdArgs...>>::type; |
| 207 | + |
| 208 | +template <class SimdCallable, class... SpmdArgs> |
| 209 | +static constexpr int get_sg_size() { |
| 210 | + if constexpr (has_uniform_signature<SimdCallable, SpmdArgs...>()) { |
| 211 | + return 0; // subgroup size does not matter then |
| 212 | + } else { |
| 213 | + return sg_size<SimdCallable, SpmdArgs...>(); |
| 214 | + } |
| 215 | +} |
| 216 | + |
| 217 | +// This function is a wrapper around a call to a functor with field or a lambda |
| 218 | +// with captures. Note __regcall - this is needed for efficient argument |
| 219 | +// forwarding. |
| 220 | +template <int N, class Callable, class... T> |
| 221 | +__regcall detail::SimdRetType<N, Callable, T...> |
| 222 | +simd_call_helper(const void *obj_ptr, |
| 223 | + typename detail::spmd2simd<T, N>::type... simd_args) { |
| 224 | + auto f = |
| 225 | + *reinterpret_cast<const std::remove_reference_t<Callable> *>(obj_ptr); |
| 226 | + return f(simd_args...); |
| 227 | +}; |
| 228 | + |
// The workaround below is enabled only for libstdc++ releases older than 10.
#if defined(_GLIBCXX_RELEASE) && (_GLIBCXX_RELEASE < 10)
#define __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
#endif // _GLIBCXX_RELEASE < 10

#ifdef __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
// TODO This is a workaround for libstdc++ version 9 buggy behavior which
// returns false in the code below. Version 10 works fine. Once required
// minimum libstdc++ version is bumped to 10, this w/a should be removed.
//   template <class F> bool foo(F &&f) {
//     return std::is_function_v<std::remove_reference_t<F>>;
//   }
// where F is a function type with __regcall.

// Type trait: true for a reference to, a pointer to, or a reference to a
// pointer to a __regcall function; false for everything else.
template <class Other>
struct is_regcall_function_ptr_or_ref : std::false_type {};

// - reference to a __regcall function
template <class Result, class... Params>
struct is_regcall_function_ptr_or_ref<Result(__regcall &)(Params...)>
    : std::true_type {};

// - pointer to a __regcall function
template <class Result, class... Params>
struct is_regcall_function_ptr_or_ref<Result(__regcall *)(Params...)>
    : std::true_type {};

// - reference to a pointer to a __regcall function
template <class Result, class... Params>
struct is_regcall_function_ptr_or_ref<Result(__regcall *&)(Params...)>
    : std::true_type {};

// Convenience variable template for the trait above.
template <class F>
static constexpr bool is_regcall_function_ptr_or_ref_v =
    is_regcall_function_ptr_or_ref<F>::value;
#endif // __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
| 261 | + |
// True when Callable is a function type, a pointer to a function or a
// reference to either of them; false for functors, lambdas and any other
// type. When the libstdc++ 9 workaround macro is defined, __regcall function
// pointers/references are additionally detected through the dedicated trait.
// Declared `inline constexpr` so that this header-defined variable template
// has a single definition across translation units.
template <class Callable>
inline constexpr bool is_function_ptr_or_ref_v =
    std::is_function_v<std::remove_pointer_t<std::remove_reference_t<Callable>>>
#ifdef __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
    || is_regcall_function_ptr_or_ref_v<Callable>
#endif // __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
    ;
| 269 | +} // namespace detail |
| 270 | + |
| 271 | +// --- The main API |
| 272 | + |
| 273 | +/// The invoke_simd free function invokes a SIMD function using all work-items |
| 274 | +/// in a sub_group. The invoke_simd interface marshals data between the SPMD |
| 275 | +/// context of the calling kernel and the SIMD context of the callee, converting |
| 276 | +/// arguments and return values between scalar and SIMD types as appropriate. |
| 277 | +/// |
| 278 | +/// @param sg the subgroup simd function is invoked from |
| 279 | +/// @param f represents the invoked simd function. |
| 280 | +/// Must be a C++ callable that can be invoked with the same number of |
| 281 | +/// arguments specified in the args parameter pack. Callable may be a function |
| 282 | +/// object, a lambda, or a function pointer (if the device supports |
| 283 | +/// SPV_INTEL_function_pointers). Callable must be an immutable callable with |
| 284 | +/// the same type and state for all work-items in the sub-group, otherwise |
| 285 | +/// behavior is undefined. |
| 286 | +/// @param args SPMD parameters to the invoked function, which undergo |
| 287 | +/// transformation before actual passing to the simd function, as described in |
| 288 | +/// the specification. |
| 289 | +// TODO works only for functions now, enable for other callables. |
| 290 | +template <class Callable, class... T> |
| 291 | +__attribute__((always_inline)) auto invoke_simd(sycl::sub_group sg, |
| 292 | + Callable &&f, T... args) { |
| 293 | + // If the invoke_simd call site is fully uniform, then it does not matter |
| 294 | + // what the subgroup size is and arguments don't need widening and return |
| 295 | + // value does not need shrinking by this library or SPMD compiler, so 0 |
| 296 | + // is fine in this case. |
| 297 | + constexpr int N = detail::get_sg_size<Callable, T...>(); |
| 298 | + using RetSpmd = detail::SpmdRetType<N, Callable, T...>; |
| 299 | + constexpr bool is_function = detail::is_function_ptr_or_ref_v<Callable>; |
| 300 | + |
| 301 | + if constexpr (is_function) { |
| 302 | + return __builtin_invoke_simd<is_function, RetSpmd>( |
| 303 | + f, detail::unwrap_uniform<T>::impl(args)...); |
| 304 | + } else { |
| 305 | + // TODO support functors and lambdas which are handled in this branch. |
| 306 | + // The limiting factor for now is that the LLVMIR data flow analysis |
| 307 | + // implemented in LowerInvokeSimd.cpp which, finds actual invoke_simd |
| 308 | + // target function, can't handle this case yet. |
| 309 | + return __builtin_invoke_simd<is_function, RetSpmd>( |
| 310 | + detail::simd_call_helper<N, Callable, T...>, &f, |
| 311 | + detail::unwrap_uniform<T>::impl(args)...); |
| 312 | + } |
| 313 | +// TODO Temporary macro and assert to enable API compilation testing. |
| 314 | +// LowerInvokeSimd.cpp does not support this case yet. |
| 315 | +#ifndef __INVOKE_SIMD_ENABLE_ALL_CALLABLES |
| 316 | + static_assert(is_function && |
| 317 | + "invoke_simd does not support functors or lambdas yet"); |
| 318 | +#endif // __INVOKE_SIMD_ENABLE_ALL_CALLABLES |
| 319 | +} |
| 320 | + |
| 321 | +} // namespace experimental |
| 322 | +} // namespace oneapi |
| 323 | +} // namespace ext |
| 324 | +} // namespace sycl |
0 commit comments