Skip to content

Commit 8689420

Browse files
authored
[SYCL] Add non-uniform group classes (#8202)
Implements basic functionality for the following group types: - ballot_group - cluster_group - tangle_group - opportunistic_group This functionality includes all member functions and type traits. Support for group functions and algorithms will follow later. --------- Signed-off-by: John Pennycook <john.pennycook@intel.com>
1 parent 9596ea0 commit 8689420

File tree

8 files changed

+676
-0
lines changed

8 files changed

+676
-0
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,16 @@ __SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
942942
__SYCL_EXPORT __ocl_vec_t<uint32_t, 4>
943943
__spirv_GroupNonUniformBallot(uint32_t Execution, bool Predicate) noexcept;
944944

945+
// TODO: I'm not 100% sure that these NonUniform instructions should be
946+
// convergent. Following the precedent set for GroupNonUniformBallot above.
947+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT uint32_t
948+
__spirv_GroupNonUniformBallotBitCount(__spv::Scope::Flag, int,
949+
__ocl_vec_t<uint32_t, 4>) noexcept;
950+
951+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT int
952+
__spirv_GroupNonUniformBallotFindLSB(__spv::Scope::Flag,
953+
__ocl_vec_t<uint32_t, 4>) noexcept;
954+
945955
extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT void
946956
__clc_BarrierInitialize(int64_t *state, int32_t expected_count) noexcept;
947957

sycl/include/CL/__spirv/spirv_vars.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ __SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInNumSubgroups;
6868
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupId;
6969
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupLocalInvocationId;
7070

71+
__SPIRV_VAR_QUALIFIERS __ocl_vec_t<uint32_t, 4> __spirv_BuiltInSubgroupEqMask;
72+
__SPIRV_VAR_QUALIFIERS __ocl_vec_t<uint32_t, 4> __spirv_BuiltInSubgroupGeMask;
73+
__SPIRV_VAR_QUALIFIERS __ocl_vec_t<uint32_t, 4> __spirv_BuiltInSubgroupGtMask;
74+
__SPIRV_VAR_QUALIFIERS __ocl_vec_t<uint32_t, 4> __spirv_BuiltInSubgroupLeMask;
75+
__SPIRV_VAR_QUALIFIERS __ocl_vec_t<uint32_t, 4> __spirv_BuiltInSubgroupLtMask;
76+
7177
// Accessor returning the x component of the __spirv_BuiltInGlobalInvocationId
// built-in variable as a size_t.
__DPCPP_SYCL_EXTERNAL inline size_t __spirv_GlobalInvocationId_x() {
7278
return __spirv_BuiltInGlobalInvocationId.x;
7379
}
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
//==------ ballot_group.hpp --- SYCL extension for non-uniform groups ------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
#include <sycl/ext/oneapi/experimental/non_uniform_groups.hpp>
11+
#include <sycl/ext/oneapi/sub_group_mask.hpp>
12+
13+
namespace sycl {
14+
__SYCL_INLINE_VER_NAMESPACE(_V1) {
15+
namespace ext::oneapi::experimental {
16+
17+
template <typename ParentGroup> class ballot_group;
18+
19+
template <typename Group>
20+
inline std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
21+
std::is_same_v<Group, sycl::sub_group>,
22+
ballot_group<Group>>
23+
get_ballot_group(Group group, bool predicate);
24+
25+
template <typename ParentGroup> class ballot_group {
26+
public:
27+
using id_type = id<1>;
28+
using range_type = range<1>;
29+
using linear_id_type = typename ParentGroup::linear_id_type;
30+
static constexpr int dimensions = 1;
31+
static constexpr sycl::memory_scope fence_scope = ParentGroup::fence_scope;
32+
33+
id_type get_group_id() const {
34+
#ifdef __SYCL_DEVICE_ONLY__
35+
return (Predicate) ? 1 : 0;
36+
#else
37+
throw runtime_error("Non-uniform groups are not supported on host device.",
38+
PI_ERROR_INVALID_DEVICE);
39+
#endif
40+
}
41+
42+
id_type get_local_id() const {
43+
#ifdef __SYCL_DEVICE_ONLY__
44+
return detail::CallerPositionInMask(Mask);
45+
#else
46+
throw runtime_error("Non-uniform groups are not supported on host device.",
47+
PI_ERROR_INVALID_DEVICE);
48+
#endif
49+
}
50+
51+
range_type get_group_range() const {
52+
#ifdef __SYCL_DEVICE_ONLY__
53+
return 2;
54+
#else
55+
throw runtime_error("Non-uniform groups are not supported on host device.",
56+
PI_ERROR_INVALID_DEVICE);
57+
#endif
58+
}
59+
60+
range_type get_local_range() const {
61+
#ifdef __SYCL_DEVICE_ONLY__
62+
return Mask.count();
63+
#else
64+
throw runtime_error("Non-uniform groups are not supported on host device.",
65+
PI_ERROR_INVALID_DEVICE);
66+
#endif
67+
}
68+
69+
linear_id_type get_group_linear_id() const {
70+
#ifdef __SYCL_DEVICE_ONLY__
71+
return static_cast<linear_id_type>(get_group_id()[0]);
72+
#else
73+
throw runtime_error("Non-uniform groups are not supported on host device.",
74+
PI_ERROR_INVALID_DEVICE);
75+
#endif
76+
}
77+
78+
linear_id_type get_local_linear_id() const {
79+
#ifdef __SYCL_DEVICE_ONLY__
80+
return static_cast<linear_id_type>(get_local_id()[0]);
81+
#else
82+
throw runtime_error("Non-uniform groups are not supported on host device.",
83+
PI_ERROR_INVALID_DEVICE);
84+
#endif
85+
}
86+
87+
linear_id_type get_group_linear_range() const {
88+
#ifdef __SYCL_DEVICE_ONLY__
89+
return static_cast<linear_id_type>(get_group_range()[0]);
90+
#else
91+
throw runtime_error("Non-uniform groups are not supported on host device.",
92+
PI_ERROR_INVALID_DEVICE);
93+
#endif
94+
}
95+
96+
linear_id_type get_local_linear_range() const {
97+
#ifdef __SYCL_DEVICE_ONLY__
98+
return static_cast<linear_id_type>(get_local_range()[0]);
99+
#else
100+
throw runtime_error("Non-uniform groups are not supported on host device.",
101+
PI_ERROR_INVALID_DEVICE);
102+
#endif
103+
}
104+
105+
bool leader() const {
106+
#ifdef __SYCL_DEVICE_ONLY__
107+
uint32_t Lowest = static_cast<uint32_t>(Mask.find_low()[0]);
108+
return __spirv_SubgroupLocalInvocationId() == Lowest;
109+
#else
110+
throw runtime_error("Non-uniform groups are not supported on host device.",
111+
PI_ERROR_INVALID_DEVICE);
112+
#endif
113+
}
114+
115+
private:
116+
sub_group_mask Mask;
117+
bool Predicate;
118+
119+
protected:
120+
ballot_group(sub_group_mask m, bool p) : Mask(m), Predicate(p) {}
121+
122+
friend ballot_group<ParentGroup>
123+
get_ballot_group<ParentGroup>(ParentGroup g, bool predicate);
124+
};
125+
126+
template <typename Group>
127+
inline std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
128+
std::is_same_v<Group, sycl::sub_group>,
129+
ballot_group<Group>>
130+
get_ballot_group(Group group, bool predicate) {
131+
(void)group;
132+
#ifdef __SYCL_DEVICE_ONLY__
133+
// ballot_group partitions into two groups using the predicate
134+
// Membership mask for one group is negation of the other
135+
sub_group_mask mask = sycl::ext::oneapi::group_ballot(group, predicate);
136+
if (predicate) {
137+
return ballot_group<sycl::sub_group>(mask, predicate);
138+
} else {
139+
return ballot_group<sycl::sub_group>(~mask, predicate);
140+
}
141+
#else
142+
(void)predicate;
143+
throw runtime_error("Non-uniform groups are not supported on host device.",
144+
PI_ERROR_INVALID_DEVICE);
145+
#endif
146+
}
147+
148+
// Marks every ballot_group specialization as a user-constructed group by
// making is_user_constructed_group derive from std::true_type for it.
template <typename ParentGroup>
149+
struct is_user_constructed_group<ballot_group<ParentGroup>> : std::true_type {};
150+
151+
} // namespace ext::oneapi::experimental
152+
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
153+
} // namespace sycl
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
//==------ cluster_group.hpp --- SYCL extension for non-uniform groups -----==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
11+
#include <sycl/ext/oneapi/experimental/non_uniform_groups.hpp>
12+
13+
namespace sycl {
14+
__SYCL_INLINE_VER_NAMESPACE(_V1) {
15+
namespace ext::oneapi::experimental {
16+
17+
template <size_t ClusterSize, typename ParentGroup> class cluster_group;
18+
19+
template <size_t ClusterSize, typename Group>
20+
inline std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
21+
std::is_same_v<Group, sycl::sub_group>,
22+
cluster_group<ClusterSize, Group>>
23+
get_cluster_group(Group group);
24+
25+
template <size_t ClusterSize, typename ParentGroup> class cluster_group {
26+
public:
27+
using id_type = id<1>;
28+
using range_type = range<1>;
29+
using linear_id_type = typename ParentGroup::linear_id_type;
30+
static constexpr int dimensions = 1;
31+
static constexpr sycl::memory_scope fence_scope = ParentGroup::fence_scope;
32+
33+
id_type get_group_id() const {
34+
#ifdef __SYCL_DEVICE_ONLY__
35+
return __spirv_SubgroupLocalInvocationId() / ClusterSize;
36+
#else
37+
throw runtime_error("Non-uniform groups are not supported on host device.",
38+
PI_ERROR_INVALID_DEVICE);
39+
#endif
40+
}
41+
42+
id_type get_local_id() const {
43+
#ifdef __SYCL_DEVICE_ONLY__
44+
return __spirv_SubgroupLocalInvocationId() % ClusterSize;
45+
#else
46+
throw runtime_error("Non-uniform groups are not supported on host device.",
47+
PI_ERROR_INVALID_DEVICE);
48+
#endif
49+
}
50+
51+
range_type get_group_range() const {
52+
#ifdef __SYCL_DEVICE_ONLY__
53+
return __spirv_SubgroupMaxSize() / ClusterSize;
54+
#else
55+
throw runtime_error("Non-uniform groups are not supported on host device.",
56+
PI_ERROR_INVALID_DEVICE);
57+
#endif
58+
}
59+
60+
range_type get_local_range() const {
61+
#ifdef __SYCL_DEVICE_ONLY__
62+
return ClusterSize;
63+
#else
64+
throw runtime_error("Non-uniform groups are not supported on host device.",
65+
PI_ERROR_INVALID_DEVICE);
66+
#endif
67+
}
68+
69+
linear_id_type get_group_linear_id() const {
70+
#ifdef __SYCL_DEVICE_ONLY__
71+
return static_cast<linear_id_type>(get_group_id()[0]);
72+
#else
73+
throw runtime_error("Non-uniform groups are not supported on host device.",
74+
PI_ERROR_INVALID_DEVICE);
75+
#endif
76+
}
77+
78+
linear_id_type get_local_linear_id() const {
79+
#ifdef __SYCL_DEVICE_ONLY__
80+
return static_cast<linear_id_type>(get_local_id()[0]);
81+
#else
82+
throw runtime_error("Non-uniform groups are not supported on host device.",
83+
PI_ERROR_INVALID_DEVICE);
84+
#endif
85+
}
86+
87+
linear_id_type get_group_linear_range() const {
88+
#ifdef __SYCL_DEVICE_ONLY__
89+
return static_cast<linear_id_type>(get_group_range()[0]);
90+
#else
91+
throw runtime_error("Non-uniform groups are not supported on host device.",
92+
PI_ERROR_INVALID_DEVICE);
93+
#endif
94+
}
95+
96+
linear_id_type get_local_linear_range() const {
97+
#ifdef __SYCL_DEVICE_ONLY__
98+
return static_cast<linear_id_type>(get_local_range()[0]);
99+
#else
100+
throw runtime_error("Non-uniform groups are not supported on host device.",
101+
PI_ERROR_INVALID_DEVICE);
102+
#endif
103+
}
104+
105+
bool leader() const {
106+
#ifdef __SYCL_DEVICE_ONLY__
107+
return get_local_linear_id() == 0;
108+
#else
109+
throw runtime_error("Non-uniform groups are not supported on host device.",
110+
PI_ERROR_INVALID_DEVICE);
111+
#endif
112+
}
113+
114+
protected:
115+
cluster_group() {}
116+
117+
friend cluster_group<ClusterSize, ParentGroup>
118+
get_cluster_group<ClusterSize, ParentGroup>(ParentGroup g);
119+
};
120+
121+
template <size_t ClusterSize, typename Group>
122+
inline std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
123+
std::is_same_v<Group, sycl::sub_group>,
124+
cluster_group<ClusterSize, Group>>
125+
get_cluster_group(Group group) {
126+
(void)group;
127+
#ifdef __SYCL_DEVICE_ONLY__
128+
return cluster_group<ClusterSize, sycl::sub_group>();
129+
#else
130+
throw runtime_error("Non-uniform groups are not supported on host device.",
131+
PI_ERROR_INVALID_DEVICE);
132+
#endif
133+
}
134+
135+
// Marks every cluster_group specialization as a user-constructed group by
// making is_user_constructed_group derive from std::true_type for it.
template <size_t ClusterSize, typename ParentGroup>
136+
struct is_user_constructed_group<cluster_group<ClusterSize, ParentGroup>>
137+
: std::true_type {};
138+
139+
} // namespace ext::oneapi::experimental
140+
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
141+
} // namespace sycl
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
//==--- non_uniform_groups.hpp --- SYCL extension for non-uniform groups ---==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
#include <CL/__spirv/spirv_ops.hpp>
11+
#include <CL/__spirv/spirv_vars.hpp>
12+
#include <sycl/ext/oneapi/sub_group_mask.hpp>
13+
14+
namespace sycl {
15+
__SYCL_INLINE_VER_NAMESPACE(_V1) {
16+
namespace ext::oneapi::experimental {
17+
18+
// Type trait: primary template, false for any type that has no
// specialization below.
template <class T> struct is_fixed_topology_group : std::false_type {};
19+
20+
// Variable-template shorthand for is_fixed_topology_group<T>::value.
template <class T>
21+
inline constexpr bool is_fixed_topology_group_v =
22+
is_fixed_topology_group<T>::value;
23+
24+
// root_group is only registered as a fixed-topology group when the
// SYCL_EXT_ONEAPI_ROOT_GROUP macro is defined.
#ifdef SYCL_EXT_ONEAPI_ROOT_GROUP
25+
template <> struct is_fixed_topology_group<root_group> : std::true_type {};
26+
#endif
27+
28+
// sycl::group of any dimensionality is a fixed-topology group.
template <int Dimensions>
29+
struct is_fixed_topology_group<sycl::group<Dimensions>> : std::true_type {};
30+
31+
// sycl::sub_group is a fixed-topology group.
template <> struct is_fixed_topology_group<sycl::sub_group> : std::true_type {};
32+
33+
// Type trait: primary template, false unless a group type provides its own
// specialization deriving from std::true_type.
template <class T> struct is_user_constructed_group : std::false_type {};
34+
35+
// Variable-template shorthand for is_user_constructed_group<T>::value.
template <class T>
36+
inline constexpr bool is_user_constructed_group_v =
37+
is_user_constructed_group<T>::value;
38+
39+
#ifdef __SYCL_DEVICE_ONLY__
40+
// TODO: This may need to be generalized beyond uint32_t for big masks
41+
namespace detail {
42+
// Returns the result of an ExclusiveScan ballot bit count of `Mask` at
// sub-group scope for the calling work-item.
// NOTE(review): per the SPIR-V specification this should be the number of set
// mask bits below the caller's lane, i.e. its index among the members —
// confirm against the OpGroupNonUniformBallotBitCount description.
//
// Declared inline: this is a non-template function defined in a header, so
// without it every translation unit including the header would emit its own
// definition and violate the one-definition rule.
inline uint32_t CallerPositionInMask(sub_group_mask Mask) {
  // FIXME: It would be nice to be able to jump straight to an __ocl_vec_t
  sycl::marray<unsigned, 4> TmpMArray;
  Mask.extract_bits(TmpMArray);
  // Copy the extracted words into a sycl::vec so they can be converted to the
  // OpenCL vector type the SPIR-V built-in expects.
  sycl::vec<unsigned, 4> MemberMask;
  for (int i = 0; i < 4; ++i) {
    MemberMask[i] = TmpMArray[i];
  }
  auto OCLMask =
      sycl::detail::ConvertToOpenCLType_t<sycl::vec<unsigned, 4>>(MemberMask);
  return __spirv_GroupNonUniformBallotBitCount(
      __spv::Scope::Subgroup,
      static_cast<int>(__spv::GroupOperation::ExclusiveScan), OCLMask);
}
56+
} // namespace detail
57+
#endif
58+
59+
} // namespace ext::oneapi::experimental
60+
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
61+
} // namespace sycl

0 commit comments

Comments
 (0)