Skip to content

Commit 291eeee

Browse files
authored
[SYCLCompat] Optimize/(fix?) permute_sub_group_by_xor if logical_sub_group_size == 32 (#16646)
`syclcompat::permute_sub_group_by_xor` was reported to fail flakily on L0. Closer inspection revealed that the implementation of `permute_sub_group_by_xor` is incorrect for cases where `logical_sub_group_size != 32`, which is one of the test cases. This implies that the test itself is wrong.

In this PR we first optimize the part of the implementation that is valid, assuming that the Intel SPIR-V builtins are correct (which is also the only case a user will realistically program): the case `logical_sub_group_size == 32`. This is done in order to:
- Ensure the only useful case is working via the correct optimized route.
- Check that this improvement doesn't break the suspicious test.

A follow-on PR can fix the other cases where `logical_sub_group_size != 32`. This is better done later, since:
- The only use case I know of for this is to implement non-uniform group algorithms that we already have implemented (e.g. see #9671), and any user is advised to use such algorithms instead of reimplementing them themselves.
- I think this must require a complete reworking of the test, and would otherwise delay the more important change here.

---------

Signed-off-by: JackAKirk <jack.kirk@codeplay.com>
1 parent e257292 commit 291eeee

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

sycl/include/syclcompat/util.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,9 @@ T shift_sub_group_right(sycl::sub_group g, T x, unsigned int delta,
410410
template <typename T>
411411
T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask,
412412
int logical_sub_group_size = 32) {
413+
if (logical_sub_group_size == 32) {
414+
return permute_group_by_xor(g, x, mask);
415+
}
413416
unsigned int id = g.get_local_linear_id();
414417
unsigned int start_index =
415418
id / logical_sub_group_size * logical_sub_group_size;

sycl/test-e2e/syclcompat/util/util_permute_sub_group_by_xor.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,9 @@ void test_permute_sub_group_by_xor() {
8686
syclcompat::device_ext &dev_ct1 = syclcompat::get_current_device();
8787
sycl::queue *q_ct1 = dev_ct1.default_queue();
8888
bool Result = true;
89-
int *dev_data = nullptr;
9089
unsigned int *dev_data_u = nullptr;
9190
sycl::range<3> GridSize(1, 1, 1);
9291
sycl::range<3> BlockSize(1, 1, 1);
93-
dev_data = sycl::malloc_device<int>(DATA_NUM, *q_ct1);
9492
dev_data_u = sycl::malloc_device<unsigned int>(DATA_NUM, *q_ct1);
9593

9694
GridSize = sycl::range<3>(1, 1, 2);
@@ -120,6 +118,19 @@ void test_permute_sub_group_by_xor() {
120118
q_ct1->memcpy(host_dev_data_u, dev_data_u, DATA_NUM * sizeof(unsigned int))
121119
.wait();
122120
verify_data<unsigned int>(host_dev_data_u, expect1, DATA_NUM);
121+
sycl::free(dev_data_u, *q_ct1);
122+
}
123+
124+
void test_permute_sub_group_by_xor_extra_arg() {
125+
std::cout << __PRETTY_FUNCTION__ << std::endl;
126+
127+
syclcompat::device_ext &dev_ct1 = syclcompat::get_current_device();
128+
sycl::queue *q_ct1 = dev_ct1.default_queue();
129+
bool Result = true;
130+
unsigned int *dev_data_u = nullptr;
131+
sycl::range<3> GridSize(1, 1, 1);
132+
sycl::range<3> BlockSize(1, 1, 1);
133+
dev_data_u = sycl::malloc_device<unsigned int>(DATA_NUM, *q_ct1);
123134

124135
GridSize = sycl::range<3>(1, 1, 2);
125136
BlockSize = sycl::range<3>(1, 2, 32);
@@ -133,6 +144,7 @@ void test_permute_sub_group_by_xor() {
133144
91, 90, 93, 92, 95, 94, 97, 96, 99, 98, 101, 100, 103, 102, 105,
134145
104, 107, 106, 109, 108, 111, 110, 113, 112, 115, 114, 117, 116, 119, 118,
135146
121, 120, 123, 122, 125, 124, 127, 126};
147+
unsigned int host_dev_data_u[DATA_NUM];
136148
init_data<unsigned int>(host_dev_data_u, DATA_NUM);
137149

138150
q_ct1->memcpy(dev_data_u, host_dev_data_u, DATA_NUM * sizeof(unsigned int))
@@ -147,13 +159,12 @@ void test_permute_sub_group_by_xor() {
147159
q_ct1->memcpy(host_dev_data_u, dev_data_u, DATA_NUM * sizeof(unsigned int))
148160
.wait();
149161
verify_data<unsigned int>(host_dev_data_u, expect2, DATA_NUM);
150-
151-
sycl::free(dev_data, *q_ct1);
152162
sycl::free(dev_data_u, *q_ct1);
153163
}
154164

155165
int main() {
156166
test_permute_sub_group_by_xor();
167+
test_permute_sub_group_by_xor_extra_arg();
157168

158169
return 0;
159170
}

0 commit comments

Comments
 (0)