Skip to content

Commit a37ca84

Browse files
authored
[SYCL] Initial implementation of invoke_simd and uniform extensions. (#5871)
- invoke_simd: https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_invoke_simd.asciidoc Current limitations: * bool parameter type not supported - uniform: https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_uniform.asciidoc Signed-off-by: Konstantin S Bobrovsky <konstantin.s.bobrovsky@intel.com> Contributions from Roland Schulz (@rolandschulz) <roland.schulz@intel.com>
1 parent 433a073 commit a37ca84

File tree

5 files changed

+776
-1
lines changed

5 files changed

+776
-1
lines changed

sycl/include/sycl/ext/intel/esimd/simd.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#include <sycl/ext/intel/esimd/detail/types.hpp>
2020
#include <sycl/ext/intel/esimd/simd_view.hpp>
2121

22+
#include <sycl/ext/oneapi/experimental/invoke_simd.hpp>
23+
2224
#ifndef __SYCL_DEVICE_ONLY__
2325
#include <iostream>
2426
#endif // __SYCL_DEVICE_ONLY__
@@ -77,6 +79,15 @@ class simd : public detail::simd_obj_impl<
7779
__esimd_dbg_print(simd(const SimdT &RHS));
7880
}
7981

82+
// Implicit conversion constructor from sycl::ext::oneapi::experimental::simd
83+
template <
84+
int N1 = N, class Ty1 = Ty,
85+
class SFINAE = std::enable_if_t<
86+
(N1 == N) && (N1 <= std::experimental::simd_abi::max_fixed_size<
87+
Ty>)&&!detail::is_wrapper_elem_type_v<Ty1>>>
88+
simd(const sycl::ext::oneapi::experimental::simd<Ty, N1> &v)
89+
: simd(static_cast<raw_vector_type>(v)) {}
90+
8091
/// Broadcast constructor with conversion. Converts given value to
8192
/// #element_type and replicates it in all elements.
8293
/// Available when \c T1 is a valid simd element type.
@@ -101,6 +112,19 @@ class simd : public detail::simd_obj_impl<
101112
return detail::convert_scalar<To, element_type>(base_type::data()[0]);
102113
}
103114

115+
/// Implicitly converts this object to a sycl::ext::oneapi::experimental::simd
116+
/// object. Available when the number of elements does not exceed maximum
117+
/// fixed size of the oneapi's simd_abi and (TODO, temporary limitation) the
118+
/// element type is a primitive type (e.g. can't be sycl::half).
119+
template <
120+
int N1, class Ty1 = Ty,
121+
class SFINAE = std::enable_if_t<
122+
(N1 == N) && (N1 <= std::experimental::simd_abi::max_fixed_size<
123+
Ty>)&&!detail::is_wrapper_elem_type_v<Ty1>>>
124+
operator sycl::ext::oneapi::experimental::simd<Ty, N1>() {
125+
return sycl::ext::oneapi::experimental::simd<Ty, N1>(base_type::data());
126+
}
127+
104128
/// Prefix increment, increments elements of this object.
105129
/// @return Reference to this object.
106130
simd &operator++() {
Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
//==------ invoke_simd.hpp - SYCL invoke_simd extension --*- C++ -*---------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
// ===--------------------------------------------------------------------=== //
8+
// Implemenation of the sycl_ext_oneapi_invoke_simd extension.
9+
// https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_invoke_simd.asciidoc
10+
// ===--------------------------------------------------------------------=== //
11+
12+
#pragma once
13+
14+
// SYCL extension macro definition as required by the SYCL specification.
15+
// 1 - Initial extension version. Base features are supported.
16+
#define SYCL_EXT_ONEAPI_INVOKE_SIMD 1
17+
18+
#include <sycl/ext/oneapi/experimental/uniform.hpp>
19+
20+
#include <CL/sycl/sub_group.hpp>
21+
#include <std/experimental/simd.hpp>
22+
#include <sycl/detail/boost/mp11.hpp>
23+
24+
#include <functional>
25+
26+
// TODOs:
27+
// * (a) TODO bool translation in spmd2simd.
28+
// * (b) TODO enforce constness of a functor/lambda's () operator
29+
// * (c) TODO support lambdas and functors in BE
30+
31+
/// Middle End - to - Back End interface to invoke explicit SIMD functions from
32+
/// SPMD SYCL context. Must not be used by user code. BEs are expected to
33+
/// recognize this intrinsic and transform the intrinsic call with a direct call
34+
/// to the SIMD target, as well as process SPMD arguments in the way described
35+
/// in the specification for `invoke_simd`.
36+
/// @tparam SpmdRet the return type. Can be `uniform<T>`.
37+
/// @tparam SimdCallee the type of the SIMD callee function (the "target"). Must
38+
/// be a function type (not lambda or functor).
39+
/// @tparam SpmdArgs The original SPMD arguments passed to the invoke_simd.
40+
template <bool IsFunc, class SpmdRet, class SimdCallee, class... SpmdArgs,
41+
class = std::enable_if_t<!IsFunc>>
42+
SYCL_EXTERNAL __regcall SpmdRet
43+
__builtin_invoke_simd(SimdCallee target, const void *obj, SpmdArgs... args)
44+
#ifdef __SYCL_DEVICE_ONLY__
45+
;
46+
#else
47+
{
48+
// __builtin_invoke_simd is not supported on the host device yet
49+
throw sycl::feature_not_supported();
50+
}
51+
#endif // __SYCL_DEVICE_ONLY__
52+
53+
template <bool IsFunc, class SpmdRet, class SimdCallee, class... SpmdArgs,
54+
class = std::enable_if_t<IsFunc>>
55+
SYCL_EXTERNAL __regcall SpmdRet __builtin_invoke_simd(SimdCallee target,
56+
SpmdArgs... args)
57+
#ifdef __SYCL_DEVICE_ONLY__
58+
;
59+
#else
60+
{
61+
// __builtin_invoke_simd is not supported on the host device yet
62+
throw sycl::feature_not_supported();
63+
}
64+
#endif // __SYCL_DEVICE_ONLY__
65+
66+
namespace sycl {
67+
namespace ext {
68+
namespace oneapi {
69+
namespace experimental {
70+
71+
// --- Basic definitions prescribed by the spec.
72+
namespace simd_abi {
73+
// "Fixed-size simd width of N" ABI based on clang vectors - used as the ABI for
74+
// SIMD objects this implementation of invoke_simd spec is based on.
75+
template <class T, int N>
76+
using native_fixed_size = typename std::experimental::__simd_abi<
77+
std::experimental::_StorageKind::_VecExt, N>;
78+
} // namespace simd_abi
79+
80+
// The SIMD object type, which is the generic std::experimental::simd type with
81+
// the native fixed size ABI.
82+
template <class T, int N>
83+
using simd = std::experimental::simd<T, simd_abi::native_fixed_size<T, N>>;
84+
85+
// The SIMD mask object type.
86+
template <class T, int N>
87+
using simd_mask =
88+
std::experimental::simd_mask<T, simd_abi::native_fixed_size<T, N>>;
89+
90+
// --- Helpers
91+
namespace detail {
92+
93+
namespace __MP11_NS = sycl::detail::boost::mp11;
94+
95+
// This structure performs the SPMD-to-SIMD parameter type conversion as defined
96+
// by the spec.
97+
template <class T, int N, class = void> struct spmd2simd;
98+
// * `uniform<T>` converts to `T`
99+
template <class T, int N> struct spmd2simd<uniform<T>, N> {
100+
using type = T;
101+
};
102+
// * tuple of types converts to tuple of converted tuple element types.
103+
template <class... T, int N> struct spmd2simd<std::tuple<T...>, N> {
104+
using type = std::tuple<typename spmd2simd<T, N>::type...>;
105+
};
106+
// * arithmetic type converts to a simd vector with this element type and the
107+
// width equal to caller's subgroup size and passed as the `N` template
108+
// argument.
109+
template <class T, int N>
110+
struct spmd2simd<T, N, std::enable_if_t<std::is_arithmetic_v<T>>> {
111+
using type = simd<T, N>;
112+
};
113+
114+
// This structure performs the SIMD-to-SPMD return type conversion as defined
115+
// by the spec.
116+
template <class, class = void> struct simd2spmd;
117+
// * `uniform<T>` stays the same
118+
template <class T> struct simd2spmd<uniform<T>> {
119+
using type = uniform<T>;
120+
};
121+
// * `simd<T, N>` converts to T
122+
template <class T, int N> struct simd2spmd<simd<T, N>> {
123+
using type = T;
124+
};
125+
// * tuple of types converts to tuple of converted tuple element types.
126+
template <class... T> struct simd2spmd<std::tuple<T...>> {
127+
using type = std::tuple<typename simd2spmd<T>::type...>;
128+
};
129+
// * arithmetic type T converts to `uniform<T>`
130+
template <class T>
131+
struct simd2spmd<T, std::enable_if_t<std::is_arithmetic_v<T>>> {
132+
using type = uniform<T>;
133+
};
134+
135+
// Check if given type is uniform.
136+
template <class T> struct is_uniform_type : std::false_type {};
137+
template <class T> struct is_uniform_type<uniform<T>> : std::true_type {
138+
using type = T;
139+
};
140+
141+
// Check if given type is simd or simd_mask.
142+
template <class T> struct is_simd_or_mask_type : std::false_type {};
143+
template <class T, int N>
144+
struct is_simd_or_mask_type<simd<T, N>> : std::true_type {};
145+
template <class T, int N>
146+
struct is_simd_or_mask_type<simd_mask<T, N>> : std::true_type {};
147+
148+
// Checks if the return value type and the types of arguments of given
149+
// SimdCallable are all uniform.
150+
template <class SimdCallable, class... SpmdArgs> struct has_uniform_signature {
151+
constexpr operator bool() {
152+
using ArgTypeList = __MP11_NS::mp_list<SpmdArgs...>;
153+
154+
if constexpr (__MP11_NS::mp_all_of<ArgTypeList, is_uniform_type>::value) {
155+
using SimdRet = std::invoke_result_t<SimdCallable, SpmdArgs...>;
156+
return is_uniform_type<SimdRet>::value ||
157+
!is_simd_or_mask_type<SimdRet>::value;
158+
} else {
159+
return false;
160+
}
161+
}
162+
};
163+
164+
// "Unwraps" a value of the `uniform` type (used before passing to SPMD
165+
// arguments to the __builtin_invoke_simd):
166+
// - the case when there is nothing to unwrap
167+
template <typename T> struct unwrap_uniform {
168+
static auto impl(T val) { return val; }
169+
};
170+
171+
// - the real unwrapping case
172+
template <typename T> struct unwrap_uniform<uniform<T>> {
173+
static T impl(uniform<T> val) { return val; }
174+
};
175+
176+
// Deduces subgroup size of the caller based on given SIMD callable and
177+
// corresponding SPMD arguments it is being invoke with via invoke_simd.
178+
// Basically, for each supported subgroup size, this meta-function finds out if
179+
// the callable can be invoked by C++ rules given the SPMD arguments transformed
180+
// as prescribed by the spec assuming this subgroup size. One and only one
181+
// subgroup size should conform.
182+
template <class SimdCallable, class... SpmdArgs> struct sg_size {
183+
template <class N>
184+
using IsInvocableSgSize = __MP11_NS::mp_bool<std::is_invocable_v<
185+
SimdCallable, typename spmd2simd<SpmdArgs, N::value>::type...>>;
186+
187+
constexpr operator int() {
188+
using SupportedSgSizes = __MP11_NS::mp_list_c<int, 1, 2, 4, 8, 16, 32>;
189+
using InvocableSgSizes =
190+
__MP11_NS::mp_copy_if<SupportedSgSizes, IsInvocableSgSize>;
191+
static_assert((__MP11_NS::mp_size<InvocableSgSizes>::value == 1) &&
192+
"no or multiple invoke_simd targets found");
193+
return __MP11_NS::mp_front<InvocableSgSizes>::value;
194+
}
195+
};
196+
197+
// Determine the return type of a SIMD callable.
198+
template <int N, class SimdCallable, class... SpmdArgs>
199+
using SimdRetType =
200+
std::invoke_result_t<SimdCallable,
201+
typename spmd2simd<SpmdArgs, N>::type...>;
202+
// Determine the return type of an invoke_simd based on the return type of a
203+
// SIMD callable.
204+
template <int N, class SimdCallable, class... SpmdArgs>
205+
using SpmdRetType =
206+
typename simd2spmd<SimdRetType<N, SimdCallable, SpmdArgs...>>::type;
207+
208+
template <class SimdCallable, class... SpmdArgs>
209+
static constexpr int get_sg_size() {
210+
if constexpr (has_uniform_signature<SimdCallable, SpmdArgs...>()) {
211+
return 0; // subgroup size does not matter then
212+
} else {
213+
return sg_size<SimdCallable, SpmdArgs...>();
214+
}
215+
}
216+
217+
// This function is a wrapper around a call to a functor with field or a lambda
218+
// with captures. Note __regcall - this is needed for efficient argument
219+
// forwarding.
220+
template <int N, class Callable, class... T>
221+
__regcall detail::SimdRetType<N, Callable, T...>
222+
simd_call_helper(const void *obj_ptr,
223+
typename detail::spmd2simd<T, N>::type... simd_args) {
224+
auto f =
225+
*reinterpret_cast<const std::remove_reference_t<Callable> *>(obj_ptr);
226+
return f(simd_args...);
227+
};
228+
229+
#ifdef _GLIBCXX_RELEASE
230+
#if _GLIBCXX_RELEASE < 10
231+
#define __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
232+
#endif // _GLIBCXX_RELEASE < 10
233+
#endif // _GLIBCXX_RELEASE
234+
235+
#ifdef __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
236+
// TODO This is a workaround for libstdc++ version 9 buggy behavior which
237+
// returns false in the code below. Version 10 works fine. Once required
238+
// minimum libstdc++ version is bumped to 10, this w/a should be removed.
239+
// template <class F> bool foo(F &&f) {
240+
// return std::is_function_v<std::remove_reference_t<F>>;
241+
// }
242+
// where F is a function type with __regcall.
243+
template <class F> struct is_regcall_function_ptr_or_ref : std::false_type {};
244+
245+
template <class Ret, class... Args>
246+
struct is_regcall_function_ptr_or_ref<Ret(__regcall &)(Args...)>
247+
: std::true_type {};
248+
249+
template <class Ret, class... Args>
250+
struct is_regcall_function_ptr_or_ref<Ret(__regcall *)(Args...)>
251+
: std::true_type {};
252+
253+
template <class Ret, class... Args>
254+
struct is_regcall_function_ptr_or_ref<Ret(__regcall *&)(Args...)>
255+
: std::true_type {};
256+
257+
template <class F>
258+
static constexpr bool is_regcall_function_ptr_or_ref_v =
259+
is_regcall_function_ptr_or_ref<F>::value;
260+
#endif // __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
261+
262+
template <class Callable>
263+
static constexpr bool is_function_ptr_or_ref_v =
264+
std::is_function_v<std::remove_pointer_t<std::remove_reference_t<Callable>>>
265+
#ifdef __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
266+
|| is_regcall_function_ptr_or_ref_v<Callable>
267+
#endif // __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
268+
;
269+
} // namespace detail
270+
271+
// --- The main API
272+
273+
/// The invoke_simd free function invokes a SIMD function using all work-items
274+
/// in a sub_group. The invoke_simd interface marshals data between the SPMD
275+
/// context of the calling kernel and the SIMD context of the callee, converting
276+
/// arguments and return values between scalar and SIMD types as appropriate.
277+
///
278+
/// @param sg the subgroup simd function is invoked from
279+
/// @param f represents the invoked simd function.
280+
/// Must be a C++ callable that can be invoked with the same number of
281+
/// arguments specified in the args parameter pack. Callable may be a function
282+
/// object, a lambda, or a function pointer (if the device supports
283+
/// SPV_INTEL_function_pointers). Callable must be an immutable callable with
284+
/// the same type and state for all work-items in the sub-group, otherwise
285+
/// behavior is undefined.
286+
/// @param args SPMD parameters to the invoked function, which undergo
287+
/// transformation before actual passing to the simd function, as described in
288+
/// the specification.
289+
// TODO works only for functions now, enable for other callables.
290+
template <class Callable, class... T>
291+
__attribute__((always_inline)) auto invoke_simd(sycl::sub_group sg,
292+
Callable &&f, T... args) {
293+
// If the invoke_simd call site is fully uniform, then it does not matter
294+
// what the subgroup size is and arguments don't need widening and return
295+
// value does not need shrinking by this library or SPMD compiler, so 0
296+
// is fine in this case.
297+
constexpr int N = detail::get_sg_size<Callable, T...>();
298+
using RetSpmd = detail::SpmdRetType<N, Callable, T...>;
299+
constexpr bool is_function = detail::is_function_ptr_or_ref_v<Callable>;
300+
301+
if constexpr (is_function) {
302+
return __builtin_invoke_simd<is_function, RetSpmd>(
303+
f, detail::unwrap_uniform<T>::impl(args)...);
304+
} else {
305+
// TODO support functors and lambdas which are handled in this branch.
306+
// The limiting factor for now is that the LLVMIR data flow analysis
307+
// implemented in LowerInvokeSimd.cpp which, finds actual invoke_simd
308+
// target function, can't handle this case yet.
309+
return __builtin_invoke_simd<is_function, RetSpmd>(
310+
detail::simd_call_helper<N, Callable, T...>, &f,
311+
detail::unwrap_uniform<T>::impl(args)...);
312+
}
313+
// TODO Temporary macro and assert to enable API compilation testing.
314+
// LowerInvokeSimd.cpp does not support this case yet.
315+
#ifndef __INVOKE_SIMD_ENABLE_ALL_CALLABLES
316+
static_assert(is_function &&
317+
"invoke_simd does not support functors or lambdas yet");
318+
#endif // __INVOKE_SIMD_ENABLE_ALL_CALLABLES
319+
}
320+
321+
} // namespace experimental
322+
} // namespace oneapi
323+
} // namespace ext
324+
} // namespace sycl

0 commit comments

Comments
 (0)