Skip to content

Commit baed6a5

Browse files
authored
[SYCL][CUDA] Add basic sub-group functionality (#2587)
Implements: - Sub-group local id - Sub-group id - Number of sub-groups - Sub-group size - Max sub-group size The implementations are functionally correct, but may benefit from additional optimization. Signed-off-by: John Pennycook <john.pennycook@intel.com> --- The implementation is different to the one proposed in https://intel.github.io/llvm-docs/cuda/opencl-subgroup-vs-cuda-crosslane-op.html, because I don't think `sreg.warpid` and `sreg.nwarpid` have the correct semantics for sub-groups. The mapping from work-items to sub-groups is invariant during a kernel's execution, which isn't true of the warp ID in PTX. As far as I can tell, the number of warp IDs represents the maximum number of warps that can execute in a CTA rather than the number of warps in a CTA. NVIDIA's [PTX documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#special-registers-warpid) says that `tid` should be used to compute a "virtual warp ID" if one is required, which is what I've implemented. I convinced myself that we have to compute the sub-group IDs and sizes from the linear size of the work-group, and couldn't find a simpler way to express this. Ideally, we wouldn't have to re-compute each of these values on every call. It would be sufficient to compute them once at the start of the kernel and then re-use them, but I don't have enough knowledge of Clang/LLVM/libclc to implement that.
1 parent 15cac43 commit baed6a5

File tree

15 files changed

+185
-20
lines changed

15 files changed

+185
-20
lines changed

libclc/generic/include/spirv/spirv.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,18 @@
3737
#include <macros.h>
3838

3939
/* 6.11.1 Work-Item Functions */
40-
#include <spirv/workitem/get_global_size.h>
4140
#include <spirv/workitem/get_global_id.h>
42-
#include <spirv/workitem/get_local_size.h>
41+
#include <spirv/workitem/get_global_offset.h>
42+
#include <spirv/workitem/get_global_size.h>
43+
#include <spirv/workitem/get_group_id.h>
4344
#include <spirv/workitem/get_local_id.h>
45+
#include <spirv/workitem/get_local_size.h>
46+
#include <spirv/workitem/get_max_sub_group_size.h>
4447
#include <spirv/workitem/get_num_groups.h>
45-
#include <spirv/workitem/get_group_id.h>
46-
#include <spirv/workitem/get_global_offset.h>
48+
#include <spirv/workitem/get_num_sub_groups.h>
49+
#include <spirv/workitem/get_sub_group_id.h>
50+
#include <spirv/workitem/get_sub_group_local_id.h>
51+
#include <spirv/workitem/get_sub_group_size.h>
4752
#include <spirv/workitem/get_work_dim.h>
4853

4954
/* 6.11.2.1 Floating-point macros */
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupMaxSize();
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
_CLC_DEF _CLC_OVERLOAD uint __spirv_NumSubgroups();
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupId();
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupLocalInvocationId();
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupSize();

libclc/ptx-nvidiacl/libspirv/SOURCES

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ workitem/get_global_size.cl
7575
workitem/get_group_id.cl
7676
workitem/get_local_id.cl
7777
workitem/get_local_size.cl
78+
workitem/get_max_sub_group_size.cl
7879
workitem/get_num_groups.cl
80+
workitem/get_num_sub_groups.cl
81+
workitem/get_sub_group_id.cl
82+
workitem/get_sub_group_local_id.cl
83+
workitem/get_sub_group_size.cl
7984
images/image_helpers.ll
8085
images/image.cl
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <spirv/spirv.h>
10+
11+
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupMaxSize() {
12+
return 32;
13+
// FIXME: warpsize is defined by NVVM IR but doesn't compile if used here
14+
// return __nvvm_read_ptx_sreg_warpsize();
15+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <spirv/spirv.h>
10+
11+
_CLC_DEF _CLC_OVERLOAD uint __spirv_NumSubgroups() {
12+
// sreg.nwarpid returns number of warp identifiers, not number of warps
13+
// see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
14+
size_t size_x = __spirv_WorkgroupSize_x();
15+
size_t size_y = __spirv_WorkgroupSize_y();
16+
size_t size_z = __spirv_WorkgroupSize_z();
17+
uint sg_size = __spirv_SubgroupMaxSize();
18+
uint linear_size = size_z * size_y * size_x;
19+
return (linear_size + sg_size - 1) / sg_size;
20+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <spirv/spirv.h>
10+
11+
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupId() {
12+
// sreg.warpid is volatile and doesn't represent virtual warp index
13+
// see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
14+
size_t id_x = __spirv_LocalInvocationId_x();
15+
size_t id_y = __spirv_LocalInvocationId_y();
16+
size_t id_z = __spirv_LocalInvocationId_z();
17+
size_t size_x = __spirv_WorkgroupSize_x();
18+
size_t size_y = __spirv_WorkgroupSize_y();
19+
size_t size_z = __spirv_WorkgroupSize_z();
20+
uint sg_size = __spirv_SubgroupMaxSize();
21+
return (id_z * size_y * size_x + id_y * size_x + id_x) / sg_size;
22+
}

0 commit comments

Comments
 (0)