Skip to content

Commit 4db72d5

Browse files
authored
[SYCL] Lift restriction that prevents us from using free function kernel queries in device compilation mode (#17398)
Free function kernel queries can, at the moment, only be used in host compilation mode requiring us to add `#ifndef __SYCL_DEVICE_ONLY__ ` directives to get code that uses them to compile. This PR removes this restriction and the instances of the workaround used in our tests. My understanding is the following: The current implementation employs SFINAE using the predicate `is_kernel_v<Func>` to delete at compile time queries that are instantiated with a `Func` that is not a free function kernel. This is certainly a reasonable interpretation of the spec which says : `Constraints: Available only if is_kernel_v<Func> is true.` for each query. However, this causes a problem. During the device front-end compiler phase of our toolchain, information is gathered regarding what the free function kernels are as defined in the code provided by the user. This means during this phase itself, `Func` will be treated as an invalid free function kernel name since this information becomes known only after the front-end has finished executing. And because SFINAE interacts with the front-end mechanisms, we get a hard error in this phase of compilation about a missing function. One simple fix, proposed in this PR, is to simply define the trait `is_kernel_v<Func>` to be constantly true in device compilation mode regardless of `Func`. This will ensure device front-end succeeds since a declaration of the query exists and it still interprets the spec reasonably because if `Func` does not represent a kernel, then during host compilation the declaration will not be available because the `is_kernel` trait fails which will print a nice message to the user saying what went wrong.
1 parent 6dbe186 commit 4db72d5

22 files changed

+20
-94
lines changed

sycl/include/sycl/ext/oneapi/experimental/free_function_traits.hpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,21 @@ inline constexpr bool is_single_task_kernel_v =
2929
is_single_task_kernel<Func>::value;
3030

3131
template <auto *Func> struct is_kernel {
32+
// During device compilation mode the compiler does not yet know
33+
// what the kernels are named because that is exactly what its trying to
34+
// figure out during this phase. Therefore, we set the is_kernel trait to true
35+
// by default during device compilation in order to not get missing functions
36+
// errors.
37+
#ifdef __SYCL_DEVICE_ONLY__
38+
static constexpr bool value = true;
39+
#else
3240
static constexpr bool value = false;
41+
#endif
3342
};
3443

3544
template <auto *Func>
3645
inline constexpr bool is_kernel_v = is_kernel<Func>::value;
3746

3847
} // namespace ext::oneapi::experimental
3948
} // namespace _V1
40-
} // namespace sycl
49+
} // namespace sycl

sycl/include/sycl/kernel_bundle.hpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,6 @@ class kernel_bundle : public detail::kernel_bundle_plain,
491491
return detail::kernel_bundle_plain::ext_oneapi_has_kernel(name);
492492
}
493493

494-
// For free functions.
495494
template <auto *Func>
496495
std::enable_if_t<ext::oneapi::experimental::is_kernel_v<Func>, bool>
497496
ext_oneapi_has_kernel() {
@@ -838,7 +837,6 @@ bool has_kernel_bundle(const context &Ctx, const std::vector<device> &Devs) {
838837
return has_kernel_bundle<State>(Ctx, Devs, {get_kernel_id<KernelName>()});
839838
}
840839

841-
// For free functions.
842840
namespace ext::oneapi::experimental {
843841
template <auto *Func, bundle_state State>
844842
std::enable_if_t<is_kernel_v<Func>, bool>
@@ -866,7 +864,6 @@ template <typename KernelName> bool is_compatible(const device &Dev) {
866864
return is_compatible({get_kernel_id<KernelName>()}, Dev);
867865
}
868866

869-
// For free functions.
870867
namespace ext::oneapi::experimental {
871868
template <auto *Func>
872869
std::enable_if_t<is_kernel_v<Func>, bool> is_compatible(const device &Dev) {

sycl/test-e2e/DeviceImageBackendContent/L0_interop_test.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <sycl/detail/core.hpp>
77
#include <sycl/ext/oneapi/backend/level_zero.hpp>
88
#include <sycl/ext/oneapi/free_function_queries.hpp>
9+
#include <sycl/kernel_bundle.hpp>
910
#include <sycl/usm.hpp>
1011
#include <vector>
1112

@@ -22,10 +23,6 @@ int main() {
2223
sycl::queue q;
2324
sycl::context ctxt = q.get_context();
2425
sycl::device d = ctxt.get_devices()[0];
25-
// The following ifndef is required due to a number of limitations of free
26-
// function kernels. See CMPLRLLVM-61498.
27-
// TODO: Remove it once these limitations are no longer there.
28-
#ifndef __SYCL_DEVICE_ONLY__
2926
// First, run the kernel using the SYCL API.
3027
auto bundle = sycl::get_kernel_bundle<sycl::bundle_state::executable>(ctxt);
3128
sycl::kernel_id iota_id = syclexp::get_kernel_id<iota>();
@@ -90,5 +87,4 @@ int main() {
9087
assert(*ptr_twin == *ptr);
9188
sycl::free(ptr, q);
9289
sycl::free(ptr_twin, q);
93-
#endif
9490
}

sycl/test-e2e/DeviceImageBackendContent/OCL_interop_test.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <sycl/detail/cl.h>
99
#include <sycl/detail/core.hpp>
1010
#include <sycl/ext/oneapi/free_function_queries.hpp>
11+
#include <sycl/kernel_bundle.hpp>
1112
#include <sycl/usm.hpp>
1213
#include <vector>
1314

@@ -24,10 +25,6 @@ int main() {
2425
sycl::queue q;
2526
sycl::context ctxt = q.get_context();
2627
sycl::device d = ctxt.get_devices()[0];
27-
// The following ifndef is required due to a number of limitations of free
28-
// function kernels. See CMPLRLLVM-61498.
29-
// TODO: Remove it once these limitations are no longer there.
30-
#ifndef __SYCL_DEVICE_ONLY__
3128
// First, run the kernel using the SYCL API.
3229

3330
auto bundle = sycl::get_kernel_bundle<sycl::bundle_state::executable>(ctxt);
@@ -77,5 +74,4 @@ int main() {
7774
assert(*ptr_twin == *ptr);
7875
sycl::free(ptr, q);
7976
sycl::free(ptr_twin, q);
80-
#endif
8177
}

sycl/test-e2e/DeviceImageDependencies/NewOffloadDriver/free_function_kernels.cpp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <iostream>
1515
#include <sycl/detail/core.hpp>
1616
#include <sycl/ext/oneapi/free_function_queries.hpp>
17+
#include <sycl/kernel_bundle.hpp>
1718
#include <sycl/usm.hpp>
1819

1920
using namespace sycl;
@@ -79,8 +80,6 @@ bool test_0(queue Queue) {
7980
std::cout << "Test 0a: " << (PassA ? "PASS" : "FAIL") << std::endl;
8081

8182
bool PassB = false;
82-
// TODO: Avoid using __SYCL_DEVICE_ONLY__ or give rationale with a comment
83-
#ifndef __SYCL_DEVICE_ONLY__
8483
kernel_bundle Bundle =
8584
get_kernel_bundle<bundle_state::executable>(Queue.get_context());
8685
kernel_id Kernel_id = ext::oneapi::experimental::get_kernel_id<ff_0>();
@@ -98,7 +97,6 @@ bool test_0(queue Queue) {
9897
std::cout << "Test 0b: " << (PassB ? "PASS" : "FAIL") << std::endl;
9998

10099
free(usmPtr, Queue);
101-
#endif
102100
return PassA && PassB;
103101
}
104102

@@ -131,8 +129,6 @@ bool test_1(queue Queue) {
131129
std::cout << "Test 1a: " << (PassA ? "PASS" : "FAIL") << std::endl;
132130

133131
bool PassB = false;
134-
// TODO: Avoid using __SYCL_DEVICE_ONLY__ or give rationale with a comment
135-
#ifndef __SYCL_DEVICE_ONLY__
136132
kernel_bundle Bundle =
137133
get_kernel_bundle<bundle_state::executable>(Queue.get_context());
138134
kernel_id Kernel_id = ext::oneapi::experimental::get_kernel_id<(
@@ -151,7 +147,6 @@ bool test_1(queue Queue) {
151147
std::cout << "Test 1b: " << (PassB ? "PASS" : "FAIL") << std::endl;
152148

153149
free(usmPtr, Queue);
154-
#endif
155150
return PassA && PassB;
156151
}
157152

@@ -189,8 +184,6 @@ bool test_2(queue Queue) {
189184
std::cout << "Test 2a: " << (PassA ? "PASS" : "FAIL") << std::endl;
190185

191186
bool PassB = false;
192-
// TODO: Avoid using __SYCL_DEVICE_ONLY__ or give rationale with a comment
193-
#ifndef __SYCL_DEVICE_ONLY__
194187
kernel_bundle Bundle =
195188
get_kernel_bundle<bundle_state::executable>(Queue.get_context());
196189
kernel_id Kernel_id =
@@ -208,7 +201,6 @@ bool test_2(queue Queue) {
208201
std::cout << "Test 2b: " << (PassB ? "PASS" : "FAIL") << std::endl;
209202

210203
free(usmPtr, Queue);
211-
#endif
212204
return PassA && PassB;
213205
}
214206

@@ -250,8 +242,6 @@ bool test_3(queue Queue) {
250242
std::cout << "Test 3a: " << (PassA ? "PASS" : "FAIL") << std::endl;
251243

252244
bool PassB = false;
253-
// TODO: Avoid using __SYCL_DEVICE_ONLY__ or give rationale with a comment
254-
#ifndef __SYCL_DEVICE_ONLY__
255245
kernel_bundle Bundle =
256246
get_kernel_bundle<bundle_state::executable>(Queue.get_context());
257247
kernel_id Kernel_id = ext::oneapi::experimental::get_kernel_id<(
@@ -269,7 +259,6 @@ bool test_3(queue Queue) {
269259
std::cout << "Test 3b: " << (PassB ? "PASS" : "FAIL") << std::endl;
270260

271261
free(usmPtr, Queue);
272-
#endif
273262
return PassA && PassB;
274263
}
275264

sycl/test-e2e/DeviceImageDependencies/free_function_kernels.cpp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <iostream>
1414
#include <sycl/detail/core.hpp>
1515
#include <sycl/ext/oneapi/free_function_queries.hpp>
16+
#include <sycl/kernel_bundle.hpp>
1617
#include <sycl/usm.hpp>
1718

1819
using namespace sycl;
@@ -78,8 +79,6 @@ bool test_0(queue Queue) {
7879
std::cout << "Test 0a: " << (PassA ? "PASS" : "FAIL") << std::endl;
7980

8081
bool PassB = false;
81-
// TODO: Avoid using __SYCL_DEVICE_ONLY__ or give rationale with a comment
82-
#ifndef __SYCL_DEVICE_ONLY__
8382
kernel_bundle Bundle =
8483
get_kernel_bundle<bundle_state::executable>(Queue.get_context());
8584
kernel_id Kernel_id = ext::oneapi::experimental::get_kernel_id<ff_0>();
@@ -97,7 +96,6 @@ bool test_0(queue Queue) {
9796
std::cout << "Test 0b: " << (PassB ? "PASS" : "FAIL") << std::endl;
9897

9998
free(usmPtr, Queue);
100-
#endif
10199
return PassA && PassB;
102100
}
103101

@@ -130,8 +128,6 @@ bool test_1(queue Queue) {
130128
std::cout << "Test 1a: " << (PassA ? "PASS" : "FAIL") << std::endl;
131129

132130
bool PassB = false;
133-
// TODO: Avoid using __SYCL_DEVICE_ONLY__ or give rationale with a comment
134-
#ifndef __SYCL_DEVICE_ONLY__
135131
kernel_bundle Bundle =
136132
get_kernel_bundle<bundle_state::executable>(Queue.get_context());
137133
kernel_id Kernel_id = ext::oneapi::experimental::get_kernel_id<(
@@ -150,7 +146,6 @@ bool test_1(queue Queue) {
150146
std::cout << "Test 1b: " << (PassB ? "PASS" : "FAIL") << std::endl;
151147

152148
free(usmPtr, Queue);
153-
#endif
154149
return PassA && PassB;
155150
}
156151

@@ -188,8 +183,6 @@ bool test_2(queue Queue) {
188183
std::cout << "Test 2a: " << (PassA ? "PASS" : "FAIL") << std::endl;
189184

190185
bool PassB = false;
191-
// TODO: Avoid using __SYCL_DEVICE_ONLY__ or give rationale with a comment
192-
#ifndef __SYCL_DEVICE_ONLY__
193186
kernel_bundle Bundle =
194187
get_kernel_bundle<bundle_state::executable>(Queue.get_context());
195188
kernel_id Kernel_id =
@@ -207,7 +200,6 @@ bool test_2(queue Queue) {
207200
std::cout << "Test 2b: " << (PassB ? "PASS" : "FAIL") << std::endl;
208201

209202
free(usmPtr, Queue);
210-
#endif
211203
return PassA && PassB;
212204
}
213205

@@ -249,8 +241,6 @@ bool test_3(queue Queue) {
249241
std::cout << "Test 3a: " << (PassA ? "PASS" : "FAIL") << std::endl;
250242

251243
bool PassB = false;
252-
// TODO: Avoid using __SYCL_DEVICE_ONLY__ or give rationale with a comment
253-
#ifndef __SYCL_DEVICE_ONLY__
254244
kernel_bundle Bundle =
255245
get_kernel_bundle<bundle_state::executable>(Queue.get_context());
256246
kernel_id Kernel_id = ext::oneapi::experimental::get_kernel_id<(
@@ -268,7 +258,6 @@ bool test_3(queue Queue) {
268258
std::cout << "Test 3b: " << (PassB ? "PASS" : "FAIL") << std::endl;
269259

270260
free(usmPtr, Queue);
271-
#endif
272261
return PassA && PassB;
273262
}
274263

sycl/test-e2e/Graph/Inputs/free_function_kernels.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ int main() {
2121

2222
Queue.memset(PtrA, 0, Size * sizeof(int)).wait();
2323

24-
#ifndef __SYCL_DEVICE_ONLY__
2524
kernel_bundle Bundle = get_kernel_bundle<bundle_state::executable>(Ctxt);
2625
kernel_id Kernel_id = exp_ext::get_kernel_id<ff_0>();
2726
kernel Kernel = Bundle.get_kernel(Kernel_id);
@@ -38,7 +37,6 @@ int main() {
3837
for (size_t i = 0; i < Size; i++) {
3938
assert(HostDataA[i] == i);
4039
}
41-
#endif
4240
sycl::free(PtrA, Queue);
4341

4442
return 0;

sycl/test-e2e/Graph/Inputs/work_group_memory_free_function.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ int main() {
2727

2828
const size_t LocalSize = 128;
2929

30-
#ifndef __SYCL_DEVICE_ONLY__
3130
kernel_bundle Bundle =
3231
get_kernel_bundle<bundle_state::executable>(Queue.get_context());
3332
kernel_id Kernel_id = exp_ext::get_kernel_id<ff_local_mem>();
@@ -57,7 +56,6 @@ int main() {
5756
int Ref = 10 + i + (Iterations * (i * 2));
5857
assert(check_value(i, Ref, HostData[i], "Ptr"));
5958
}
60-
#endif
6159

6260
free(Ptr, Queue);
6361
return 0;

sycl/test-e2e/Graph/Update/FreeFunctionKernels/update_before_finalize.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ int main() {
3030

3131
exp_ext::dynamic_parameter InputParam(Graph, PtrA);
3232

33-
#ifndef __SYCL_DEVICE_ONLY__
3433
kernel_bundle Bundle = get_kernel_bundle<bundle_state::executable>(Ctxt);
3534
kernel_id Kernel_id = exp_ext::get_kernel_id<ff_0>();
3635
kernel Kernel = Bundle.get_kernel(Kernel_id);
@@ -52,7 +51,6 @@ int main() {
5251
assert(HostDataA[i] == 0);
5352
assert(HostDataB[i] == i);
5453
}
55-
#endif
5654
sycl::free(PtrA, Queue);
5755
sycl::free(PtrB, Queue);
5856

sycl/test-e2e/Graph/Update/FreeFunctionKernels/update_with_indices_multiple_exec_graphs.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ int main() {
3131

3232
exp_ext::dynamic_parameter InputParam(Graph, PtrA);
3333

34-
#ifndef __SYCL_DEVICE_ONLY__
3534
kernel_bundle Bundle = get_kernel_bundle<bundle_state::executable>(Ctxt);
3635
kernel_id Kernel_id = exp_ext::get_kernel_id<ff_1>();
3736
kernel Kernel = Bundle.get_kernel(Kernel_id);
@@ -69,7 +68,6 @@ int main() {
6968
assert(HostDataA[i] == i * 3);
7069
assert(HostDataB[i] == i);
7170
}
72-
#endif
7371
sycl::free(PtrA, Queue);
7472
sycl::free(PtrB, Queue);
7573

0 commit comments

Comments
 (0)